# Fine-tuning IndoBART for Indonesian (ID) -> Sundanese (SU) machine translation

In [1]:
import os, sys
sys.path.append('../')
os.chdir('../')

import torch
import shutil
import random
import numpy as np
import pandas as pd
from torch import optim
from transformers import MBartForConditionalGeneration

from indobenchmark import IndoNLGTokenizer
from utils.train_eval import train, evaluate
from utils.metrics import generation_metrics_fn
from utils.forward_fn import forward_generation
from utils.data_utils import MachineTranslationDataset, GenerationDataLoader

In [2]:
###
# common functions
###
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    # Same order as before: Python -> NumPy -> torch CPU -> torch CUDA.
    for apply_seed in seeders:
        apply_seed(seed)
    
def count_param(module, trainable=False):
    """Return the total number of parameter elements in `module`.

    When `trainable` is truthy, only parameters with `requires_grad` set are counted.
    """
    total = 0
    for param in module.parameters():
        if trainable and not param.requires_grad:
            # Frozen parameter: skipped only when the caller asked for trainable ones.
            continue
        total += param.numel()
    return total
    
# Set random seed
# Seeding is required for a reproducible run: the training loader shuffles and the
# notebook already requests deterministic cuDNN kernels, which is moot without a fixed seed.
set_seed(26092020)

# Load Model

In [3]:
# The model weights and the tokenizer are loaded from the same checkpoint identifier.
pretrained_checkpoint = 'indobenchmark/indobart'
# pretrained_checkpoint = 'indobenchmark/indobart-v2'

bart_model = MBartForConditionalGeneration.from_pretrained(pretrained_checkpoint)
tokenizer = IndoNLGTokenizer.from_pretrained(pretrained_checkpoint)

# `model` is the name used by the rest of the notebook; it aliases `bart_model`.
model = bart_model
model

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): Embedding(40004, 768, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): Embedding(40004, 768, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 768, padding_idx=1)
      (layers): ModuleList(
        (0): MBartEncoderLayer(
          (self_attn): MBartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=Tru

In [4]:
count_param(model)

131543040

# Prepare Dataset

In [5]:
# configs and args
# Every tunable value for the experiment lives in this cell; later cells read these names.

# Optimisation settings (forwarded to train() below in the notebook)
lr = 1e-4
gamma = 0.9
lower = True
step_size = 1
beam_size = 5
max_norm = 10
early_stop = 5

# Sequence / data handling settings
max_seq_len = 512
grad_accumulate = 1
no_special_token = False
swap_source_target = True  # NOTE(review): presumably flips the dataset's source/target columns -- confirm in MachineTranslationDataset
model_type = 'indo-bart'
valid_criterion = 'SacreBLEU'

# Token ids forwarded to the dataset constructor
separator_id = 4
speaker_1_id = 5
speaker_2_id = 6

train_batch_size = 8
valid_batch_size = 8
test_batch_size = 8

source_lang = "[indonesian]"
target_lang = "[sundanese]"

# Language-id tokens looked up from the tokenizer's special-token table
src_lid = tokenizer.special_tokens_to_ids[source_lang]
tgt_lid = tokenizer.special_tokens_to_ids[target_lang]

# Generation starts decoding from the target-language token
model.config.decoder_start_token_id = tgt_lid

# Make sure cuda is deterministic
torch.backends.cudnn.deterministic = True

# create directory (exist_ok makes a separate existence check unnecessary)
model_dir = './save/MT_SUNIBS_INZNTV/example_id_su'
os.makedirs(model_dir, exist_ok=True)

device = 'cuda0'
# set a specific cuda device
if "cuda" in device:
    # Accept 'cuda0', 'cuda:0' and plain 'cuda'; only pin a device when an index is given.
    gpu_index = device.replace("cuda", "").lstrip(":")
    if gpu_index:
        torch.cuda.set_device(int(gpu_index))
    device = "cuda"
    model = model.cuda()

# Build the optimizer only after the model sits on its final device, as the
# torch.optim documentation recommends.
optimizer = optim.Adam(model.parameters(), lr=lr)

In [6]:
# Locations of the preprocessed train / validation / test splits.
train_dataset_path = './dataset/MT_SUNIBS_INZNTV/train_preprocess.json'
valid_dataset_path = './dataset/MT_SUNIBS_INZNTV/valid_preprocess.json'
test_dataset_path = './dataset/MT_SUNIBS_INZNTV/test_preprocess.json'

# Keyword arguments shared by the three dataset splits.
dataset_kwargs = dict(
    lowercase=lower,
    no_special_token=no_special_token,
    speaker_1_id=speaker_1_id,
    speaker_2_id=speaker_2_id,
    separator_id=separator_id,
    max_token_length=max_seq_len,
    swap_source_target=swap_source_target,
)
train_dataset = MachineTranslationDataset(train_dataset_path, tokenizer, **dataset_kwargs)
valid_dataset = MachineTranslationDataset(valid_dataset_path, tokenizer, **dataset_kwargs)
test_dataset = MachineTranslationDataset(test_dataset_path, tokenizer, **dataset_kwargs)

# Keyword arguments shared by the three loaders; only the training loader shuffles.
loader_kwargs = dict(
    model_type=model_type,
    tokenizer=tokenizer,
    max_seq_len=max_seq_len,
    src_lid_token_id=src_lid,
    tgt_lid_token_id=tgt_lid,
    num_workers=8,
)
train_loader = GenerationDataLoader(dataset=train_dataset, batch_size=train_batch_size,
                                    shuffle=True, **loader_kwargs)
valid_loader = GenerationDataLoader(dataset=valid_dataset, batch_size=valid_batch_size,
                                    shuffle=False, **loader_kwargs)
test_loader = GenerationDataLoader(dataset=test_dataset, batch_size=test_batch_size,
                                   shuffle=False, **loader_kwargs)

# Sanity check: masked-token infilling with the pretrained model

In [7]:
# Sanity check: let the pretrained model fill the <mask> in an Indonesian sentence.
inputs = ['aku pergi ke toko obat membeli <mask>']
bart_input = tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                         lang_token = '[indonesian]', decoder_lang_token='[indonesian]')

bart_input.to(device)
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    bart_out = model(**bart_input)
print(tokenizer.decode(bart_input['input_ids'][0]))
# Highest-scoring token at every position of the single sequence in the batch.
print(tokenizer.decode(bart_out.logits.argmax(dim=-1).squeeze()))

2021-10-24 17:05:52.724648: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


<s> aku pergi ke toko obat membeli <mask> [indonesian]
<s> aku pergi ke toko obat membeli obat.


In [8]:
# Sanity check: let the pretrained model fill the <mask> in a Javanese sentence.
inputs = ['aku menyang pasar karo <mask>']
bart_input = tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                         lang_token = '[javanese]', decoder_lang_token='[javanese]')

bart_input.to(device)
# `model` and `bart_model` are the same object; `model` is used for consistency with the
# rest of the notebook. no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    bart_out = model(**bart_input)
print(tokenizer.decode(bart_input['input_ids'][0]))
# Highest-scoring token at every position of the single sequence in the batch.
print(tokenizer.decode(bart_out.logits.argmax(dim=-1).squeeze()))

<s> aku menyang pasar karo <mask> [javanese]
<s> aku menyang pasar karo tuku </s>


In [9]:
# Sanity check: let the pretrained model fill the <mask> in a Sundanese sentence.
inputs = ['kuring ka pasar senen meuli daging <mask>']
bart_input = tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                         lang_token = '[sundanese]', decoder_lang_token='[sundanese]')

bart_input.to(device)
# `model` and `bart_model` are the same object; `model` is used for consistency with the
# rest of the notebook. no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    bart_out = model(**bart_input)
print(tokenizer.decode(bart_input['input_ids'][0]))
# Highest-scoring token at every position of the single sequence in the batch.
print(tokenizer.decode(bart_out.logits.argmax(dim=-1).squeeze()))

<s> kuring ka pasar senen meuli daging <mask> [sundanese]
<s> kuring ka pasar senen meuli daging sapi </s>


# Baseline: translate the test set with the pretrained model (before fine-tuning)

In [10]:
# Baseline: score the pretrained model on the test loader before any fine-tuning.
# Uses the configured `device` rather than a hard-coded 'cuda' so this cell follows
# whatever the config cell selected.
test_loss, test_metrics, test_hyp, test_label = evaluate(model, data_loader=test_loader, forward_fn=forward_generation, 
                                                         metrics_fn=generation_metrics_fn, model_type=model_type, 
                                                         tokenizer=tokenizer, beam_size=beam_size, 
                                                         max_seq_len=max_seq_len, is_test=True, 
                                                         device=device)

TESTING... : 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:30<00:00,  1.65it/s]


In [11]:
# Report the baseline (pre-fine-tuning) test results produced by the evaluate() cell above.
metrics_scores = []
result_dfs = []

metrics_scores.append(test_metrics)
result_dfs.append(pd.DataFrame({
    'hyp': test_hyp, 
    'label': test_label
}))

result_df = pd.concat(result_dfs)
metric_df = pd.DataFrame.from_records(metrics_scores)

print('== Prediction Result ==')
print(result_df.head())
print()

print('== Model Performance ==')
print(metric_df.describe())

# Baseline results get their own file names: the post-training report cell writes
# prediction_result.csv / evaluation_result.csv into the same directory and would
# otherwise silently overwrite these.
result_df.to_csv(model_dir + "/zero_shot_prediction_result.csv")
metric_df.describe().to_csv(model_dir + "/zero_shot_evaluation_result.csv")

== Prediction Result ==
                                                 hyp  \
0  memang aku tidak sadar akan sesuatu, tetapi bu...   
1              mereka semuanya makan sampai kenyang.   
2  kita ini buatan allah, diciptakan dalam kristu...   
3  orang suci semuanya suci ; tetapi bagi orang n...   
4  maka pergilah petrus dan yohanes kepada teman ...   

                                               label  
0  da teu terang naon - na on, tur can tangtu sim...  
1    terus kabeh dal alah ar nepi ka sare ub euh na.  
2  sabab urang mah darma dad am elan allah, anu g...  
3  pikeun anu hat ena suci mah, sagala ge suci. s...  
4  sanggeus leupas, petrus jeung yahya nep angan ...  

== Model Performance ==
           BLEU  SacreBLEU    ROUGE1    ROUGE2    ROUGEL  ROUGELsum
count  1.000000   1.000000  1.000000  1.000000  1.000000    1.00000
mean   1.853613   1.876446  8.022539  1.906093  7.640285    7.64969
std         NaN        NaN       NaN       NaN       NaN        NaN
min    1.853

# Fine Tuning & Evaluation

In [12]:
# Train
# Fine-tune for up to `n_epochs` epochs, validating after every epoch.

n_epochs = 10

train(
    model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
    forward_fn=forward_generation,
    metrics_fn=generation_metrics_fn,
    valid_criterion=valid_criterion,
    tokenizer=tokenizer,
    n_epochs=n_epochs,
    evaluate_every=1,
    early_stop=early_stop,
    grad_accum=grad_accumulate,
    step_size=step_size,
    gamma=gamma,
    max_norm=max_norm,
    model_type=model_type,
    beam_size=beam_size,
    max_seq_len=max_seq_len,
    model_dir=model_dir,
    exp_id=0,
    fp16="",
    device=device,
)

(Epoch 1) TRAIN LOSS:3.4359 LR:0.00010000: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [01:28<00:00,  8.44it/s]


(Epoch 1) TRAIN LOSS:3.4359 BLEU:26.00 SacreBLEU:27.76 ROUGE1:43.82 ROUGE2:16.77 ROUGEL:39.59 ROUGELsum:39.60 LR:0.00010000


VALID LOSS:2.9707: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 30.13it/s]


(Epoch 1) VALID LOSS:2.9707 BLEU:18.26 SacreBLEU:18.41 ROUGE1:44.99 ROUGE2:16.78 ROUGEL:40.54 ROUGELsum:40.54


(Epoch 2) TRAIN LOSS:2.5026 LR:0.00009000: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [01:28<00:00,  8.45it/s]


(Epoch 2) TRAIN LOSS:2.5026 BLEU:32.25 SacreBLEU:34.05 ROUGE1:53.80 ROUGE2:25.47 ROUGEL:50.03 ROUGELsum:50.04 LR:0.00009000


VALID LOSS:2.8121: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 30.12it/s]


(Epoch 2) VALID LOSS:2.8121 BLEU:21.00 SacreBLEU:21.10 ROUGE1:47.60 ROUGE2:20.07 ROUGEL:43.41 ROUGELsum:43.43


(Epoch 3) TRAIN LOSS:1.9167 LR:0.00008100: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [01:28<00:00,  8.44it/s]


(Epoch 3) TRAIN LOSS:1.9167 BLEU:38.30 SacreBLEU:39.95 ROUGE1:61.35 ROUGE2:33.81 ROUGEL:58.19 ROUGELsum:58.19 LR:0.00008100


VALID LOSS:2.7939: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.78it/s]


(Epoch 3) VALID LOSS:2.7939 BLEU:21.97 SacreBLEU:22.10 ROUGE1:48.85 ROUGE2:21.40 ROUGEL:44.71 ROUGELsum:44.70


(Epoch 4) TRAIN LOSS:1.4321 LR:0.00007290: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [01:28<00:00,  8.44it/s]


(Epoch 4) TRAIN LOSS:1.4321 BLEU:45.91 SacreBLEU:47.36 ROUGE1:68.94 ROUGE2:43.75 ROUGEL:66.50 ROUGELsum:66.50 LR:0.00007290


VALID LOSS:2.8450: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.66it/s]


(Epoch 4) VALID LOSS:2.8450 BLEU:22.63 SacreBLEU:22.74 ROUGE1:49.29 ROUGE2:21.90 ROUGEL:45.29 ROUGELsum:45.29


(Epoch 5) TRAIN LOSS:1.0360 LR:0.00006561: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [01:28<00:00,  8.44it/s]


(Epoch 5) TRAIN LOSS:1.0360 BLEU:55.46 SacreBLEU:56.64 ROUGE1:76.38 ROUGE2:55.30 ROUGEL:74.71 ROUGELsum:74.69 LR:0.00006561


VALID LOSS:2.9274: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 30.17it/s]


(Epoch 5) VALID LOSS:2.9274 BLEU:22.38 SacreBLEU:22.46 ROUGE1:49.10 ROUGE2:21.79 ROUGEL:45.20 ROUGELsum:45.23
count stop: 1


(Epoch 6) TRAIN LOSS:0.7306 LR:0.00005905: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [01:28<00:00,  8.39it/s]


(Epoch 6) TRAIN LOSS:0.7306 BLEU:65.92 SacreBLEU:66.79 ROUGE1:82.99 ROUGE2:67.11 ROUGEL:81.99 ROUGELsum:81.99 LR:0.00005905


VALID LOSS:3.0414: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.91it/s]


(Epoch 6) VALID LOSS:3.0414 BLEU:22.77 SacreBLEU:22.89 ROUGE1:49.43 ROUGE2:22.37 ROUGEL:45.59 ROUGELsum:45.61


(Epoch 7) TRAIN LOSS:0.5174 LR:0.00005314: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [01:28<00:00,  8.39it/s]


(Epoch 7) TRAIN LOSS:0.5174 BLEU:75.09 SacreBLEU:75.72 ROUGE1:88.08 ROUGE2:76.75 ROUGEL:87.53 ROUGELsum:87.52 LR:0.00005314


VALID LOSS:3.1377: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.88it/s]


(Epoch 7) VALID LOSS:3.1377 BLEU:22.95 SacreBLEU:23.07 ROUGE1:49.38 ROUGE2:22.52 ROUGEL:45.60 ROUGELsum:45.60


(Epoch 8) TRAIN LOSS:0.3631 LR:0.00004783: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [01:28<00:00,  8.39it/s]


(Epoch 8) TRAIN LOSS:0.3631 BLEU:82.62 SacreBLEU:83.06 ROUGE1:91.77 ROUGE2:84.02 ROUGEL:91.51 ROUGELsum:91.52 LR:0.00004783


VALID LOSS:3.2386: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.99it/s]


(Epoch 8) VALID LOSS:3.2386 BLEU:23.14 SacreBLEU:23.23 ROUGE1:49.33 ROUGE2:22.64 ROUGEL:45.68 ROUGELsum:45.65


(Epoch 9) TRAIN LOSS:0.2644 LR:0.00004305: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [01:28<00:00,  8.43it/s]


(Epoch 9) TRAIN LOSS:0.2644 BLEU:87.60 SacreBLEU:87.92 ROUGE1:94.25 ROUGE2:88.85 ROUGEL:94.10 ROUGELsum:94.11 LR:0.00004305


VALID LOSS:3.3427: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.91it/s]


(Epoch 9) VALID LOSS:3.3427 BLEU:23.20 SacreBLEU:23.30 ROUGE1:49.36 ROUGE2:22.49 ROUGEL:45.61 ROUGELsum:45.60


(Epoch 10) TRAIN LOSS:0.1967 LR:0.00003874: 100%|████████████████████████████████████████████████████████████████████████████████████| 746/746 [01:28<00:00,  8.45it/s]


(Epoch 10) TRAIN LOSS:0.1967 BLEU:91.17 SacreBLEU:91.41 ROUGE1:95.92 ROUGE2:92.09 ROUGEL:95.85 ROUGELsum:95.85 LR:0.00003874


VALID LOSS:3.4304: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.84it/s]


(Epoch 10) VALID LOSS:3.4304 BLEU:23.17 SacreBLEU:23.27 ROUGE1:49.44 ROUGE2:22.61 ROUGEL:45.77 ROUGELsum:45.78
count stop: 1


In [13]:
# Load best model
# NOTE(review): the checkpoint is presumably written by train() for exp_id=0 -- confirm in utils.train_eval.
# map_location places the tensors on the configured device, so loading does not depend on
# the GPU index the checkpoint was saved from.
best_checkpoint_path = model_dir + "/best_model_0.th"
model.load_state_dict(torch.load(best_checkpoint_path, map_location=device))

<All keys matched successfully>

In [14]:
# Evaluate
# Score the reloaded checkpoint on the test loader. Uses the configured `device` rather
# than a hard-coded 'cuda' so this cell follows whatever the config cell selected.
test_loss, test_metrics, test_hyp, test_label = evaluate(model, data_loader=test_loader, forward_fn=forward_generation, 
                                                         metrics_fn=generation_metrics_fn, model_type=model_type, 
                                                         tokenizer=tokenizer, beam_size=beam_size, 
                                                         max_seq_len=max_seq_len, is_test=True, 
                                                         device=device)

TESTING... : 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:37<00:00,  1.55it/s]


In [15]:
# Gather the test metrics and the hypothesis/reference pairs from the evaluate() call above.
metrics_scores = [test_metrics]
result_dfs = [pd.DataFrame({'hyp': test_hyp, 'label': test_label})]

result_df = pd.concat(result_dfs)
metric_df = pd.DataFrame.from_records(metrics_scores)
# Summary statistics are computed once and reused for printing and saving.
metric_summary = metric_df.describe()

print('== Prediction Result ==')
print(result_df.head())
print()

print('== Model Performance ==')
print(metric_summary)

# Persist the predictions and the metric summary next to the model checkpoint.
prediction_csv_path = model_dir + "/prediction_result.csv"
evaluation_csv_path = model_dir + "/evaluation_result.csv"
result_df.to_csv(prediction_csv_path)
metric_summary.to_csv(evaluation_csv_path)

== Prediction Result ==
                                                 hyp  \
0  memang pam oh alan, ku simkuring teu sadar. ta...   
1    kabeh oge, dal alah ar nepi ka sare ub euh eun.   
2  sabab urang teh, mahluk allah anu diciptakan k...   
3  sabalikna, jelema - j el ema anu suci mah, kab...   
4  ti dinya, petrus jeung yahya dib e ungk eut ke...   

                                               label  
0  da teu terang naon - na on, tur can tangtu sim...  
1    terus kabeh dal alah ar nepi ka sare ub euh na.  
2  sabab urang mah darma dad am elan allah, anu g...  
3  pikeun anu hat ena suci mah, sagala ge suci. s...  
4  sanggeus leupas, petrus jeung yahya nep angan ...  

== Model Performance ==
            BLEU  SacreBLEU     ROUGE1     ROUGE2    ROUGEL  ROUGELsum
count   1.000000   1.000000   1.000000   1.000000   1.00000   1.000000
mean   15.366755  15.356161  36.183704  16.636455  31.35292  31.363771
std          NaN        NaN        NaN        NaN       NaN        NaN
