# Tutorial NLG - Finetuning SU -> ID translation using IndoBART

In [1]:
import os, sys
sys.path.append('../')
os.chdir('../')
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import torch
import shutil
import random
import numpy as np
import pandas as pd
from torch import optim
from transformers import MBartForConditionalGeneration

from indobenchmark import IndoNLGTokenizer
from utils.train_eval import train, evaluate
from utils.metrics import generation_metrics_fn
from utils.forward_fn import forward_generation
from utils.data_utils import MachineTranslationDataset, GenerationDataLoader

In [2]:
###
# common functions
###
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
def count_param(module, trainable=False):
    if trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())
    
# Set random seed
# set_seed(26092020)

# Load Model

In [3]:
bart_model = MBartForConditionalGeneration.from_pretrained('indobenchmark/indobart-v2')
tokenizer = IndoNLGTokenizer.from_pretrained('indobenchmark/indobart-v2')

model = bart_model
model

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): Embedding(40004, 768, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): Embedding(40004, 768, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0): MBartEncoderLayer(
          (self_attn): MBartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [4]:
count_param(model)

131543040

# Prepare Dataset

In [5]:
# configs and args

lr = 5e-5
gamma = 0.9
lower = True
step_size = 1
beam_size = 5
max_norm = 10
early_stop = 5

max_seq_len = 512
grad_accumulate = 1
no_special_token = False
swap_source_target = False
model_type = 'indo-bart'
valid_criterion = 'SacreBLEU'

separator_id = 4
speaker_1_id = 5
speaker_2_id = 6

train_batch_size = 8
valid_batch_size = 8
test_batch_size = 8

source_lang = "[indonesian]"
target_lang = "[sundanese]"

optimizer = optim.Adam(model.parameters(), lr=lr)
src_lid = tokenizer.special_tokens_to_ids[source_lang]
tgt_lid = tokenizer.special_tokens_to_ids[target_lang]

model.config.decoder_start_token_id = tgt_lid

# Make sure cuda is deterministic
torch.backends.cudnn.deterministic = True

# create directory
model_dir = './save/MT_SUNIBS_INZNTV/example_id_su'
if not os.path.exists(model_dir):
    os.makedirs(model_dir, exist_ok=True)

device = 'cuda0'
# set a specific cuda device
if "cuda" in device:
    torch.cuda.set_device(int(device[4:]))
    device = "cuda"
    model = model.cuda()

In [6]:
train_dataset_path = './dataset/MT_SUNIBS_INZNTV/train_preprocess.json'
valid_dataset_path = './dataset/MT_SUNIBS_INZNTV/valid_preprocess.json'
test_dataset_path = './dataset/MT_SUNIBS_INZNTV/test_preprocess.json'

train_dataset = MachineTranslationDataset(train_dataset_path, tokenizer, lowercase=lower, no_special_token=no_special_token, 
                                            speaker_1_id=speaker_1_id, speaker_2_id=speaker_2_id, separator_id=separator_id,
                                            max_token_length=max_seq_len, swap_source_target=swap_source_target)
valid_dataset = MachineTranslationDataset(valid_dataset_path, tokenizer, lowercase=lower, no_special_token=no_special_token, 
                                            speaker_1_id=speaker_1_id, speaker_2_id=speaker_2_id, separator_id=separator_id,
                                            max_token_length=max_seq_len, swap_source_target=swap_source_target)
test_dataset = MachineTranslationDataset(test_dataset_path, tokenizer, lowercase=lower, no_special_token=no_special_token, 
                                            speaker_1_id=speaker_1_id, speaker_2_id=speaker_2_id, separator_id=separator_id,
                                            max_token_length=max_seq_len, swap_source_target=swap_source_target)

train_loader = GenerationDataLoader(dataset=train_dataset, model_type=model_type, tokenizer=tokenizer, max_seq_len=max_seq_len, 
                                    batch_size=train_batch_size, src_lid_token_id=src_lid, tgt_lid_token_id=tgt_lid, num_workers=8, shuffle=True)  
valid_loader = GenerationDataLoader(dataset=valid_dataset, model_type=model_type, tokenizer=tokenizer, max_seq_len=max_seq_len, 
                                    batch_size=valid_batch_size, src_lid_token_id=src_lid, tgt_lid_token_id=tgt_lid, num_workers=8, shuffle=False)
test_loader = GenerationDataLoader(dataset=test_dataset, model_type=model_type, tokenizer=tokenizer, max_seq_len=max_seq_len, 
                                   batch_size=test_batch_size, src_lid_token_id=src_lid, tgt_lid_token_id=tgt_lid, num_workers=8, shuffle=False)

# Test model to generate sequences

In [7]:
inputs = ['aku pergi ke toko obat membeli <mask>']
bart_input = tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                         lang_token = '[indonesian]', decoder_lang_token='[indonesian]')

bart_input.to(device)
bart_out = model(**bart_input)
print(tokenizer.decode(bart_input['input_ids'][0]))
print(tokenizer.decode(bart_out.logits.topk(1).indices[:,:].squeeze()))

<s> aku pergi ke toko obat membeli<mask></s>[indonesian]
<s> aku pergi ke toko obat membeli obat.[indonesian]


In [9]:
inputs = ['kuring ka pasar senen meuli daging <mask>']
bart_input = tokenizer.prepare_input_for_generation(inputs, return_tensors='pt',
                                         lang_token = '[sundanese]', decoder_lang_token='[sundanese]')

bart_input.to(device)
bart_out = bart_model(**bart_input)
print(tokenizer.decode(bart_input['input_ids'][0]))
print(tokenizer.decode(bart_out.logits.topk(1).indices[:,:].squeeze()))

<s> kuring ka pasar senen meuli daging<mask></s>[sundanese]
<s> kuring ka pasar senen meuli daging sapi,[sundanese]


# Test model to translate

In [10]:
test_loss, test_metrics, test_hyp, test_label = evaluate(model, data_loader=test_loader, forward_fn=forward_generation, 
                                                         metrics_fn=generation_metrics_fn, model_type=model_type, 
                                                         tokenizer=tokenizer, beam_size=beam_size, 
                                                         max_seq_len=max_seq_len, is_test=True, 
                                                         device='cuda')

TESTING... : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [02:59<00:00,  1.20s/it]


In [11]:
metrics_scores = []
result_dfs = []

metrics_scores.append(test_metrics)
result_dfs.append(pd.DataFrame({
    'hyp': test_hyp, 
    'label': test_label
}))

result_df = pd.concat(result_dfs)
metric_df = pd.DataFrame.from_records(metrics_scores)

print('== Prediction Result ==')
print(result_df.head())
print()

print('== Model Performance ==')
print(metric_df.describe())

result_df.to_csv(model_dir + "/prediction_result.csv")
metric_df.describe().to_csv(model_dir + "/evaluation_result.csv")

== Prediction Result ==
                                                 hyp  \
0   teu terang naon-naon, tur can tangtu simkurin...   
1               kabeh dalalahar nepi ka sareubeuhna.   
2   urang mah darma dadamelan allah, anu geus dic...   
3   anu hatena suci mah, sagala ge suci. sabalikn...   
4   leupas, petrus jeung yahya nepangan rerencang...   

                                               label  
0   sebab memang aku tidak sadar akan sesuatu, te...  
1          dan mereka semuanya makan sampai kenyang.  
2   karena kita ini buatan allah, diciptakan dala...  
3   bagi orang suci semuanya suci; tetapi bagi or...  
4   sesudah dilepaskan pergilah petrus dan yohane...  

== Model Performance ==
           BLEU  SacreBLEU    ROUGE1   ROUGE2    ROUGEL  ROUGELsum
count  1.000000   1.000000  1.000000  1.00000  1.000000   1.000000
mean   0.732548   0.742814  7.358197  1.06949  7.065098   7.063007
std         NaN        NaN       NaN      NaN       NaN        NaN
min    0.732548 

In [16]:
for i, (hyp, label) in enumerate(zip(test_hyp, test_label)):
    print(hyp, ' | ', label)
    if i == 5:
        break

 aku tidak tahu apa yang harus aku katakan, dan yang tidak mungkin aku kemukakan. hanya tuhan yang akan menentukan aku.  |   sebab memang aku tidak sadar akan sesuatu, tetapi bukan karena itulah aku dibenarkan. dia, yang menghakimi aku, ialah tuhan.
 mereka semuanya makan sampai kenyang.  |   dan mereka semuanya makan sampai kenyang.
 karena kita adalah ciptaan allah yang telah diciptakan allah di dalam kristus untuk berbuat baik, yang disediakan allah bagi kita untuk melakukan kehendak-nya.  |   karena kita ini buatan allah, diciptakan dalam kristus yesus untuk melakukan pekerjaan baik, yang dipersiapkan allah sebelumnya. ia mau, supaya kita hidup di dalamnya.
 sebab segala sesuatu kudus. tetapi orang-orang yang lemah iman tidak pernah menjadi kudus, oleh karena pikiran-pikiran yang kotor.  |   bagi orang suci semuanya suci; tetapi bagi orang najis dan bagi orang tidak beriman suatu pun tidak ada yang suci, karena baik akal maupun suara hati mereka najis.
 ketika petrus dan yohanes da

# Fine Tuning & Evaluation

In [12]:
# Train
n_epochs = 10

train(model, train_loader=train_loader, valid_loader=valid_loader, optimizer=optimizer, 
      forward_fn=forward_generation, metrics_fn=generation_metrics_fn, valid_criterion=valid_criterion, 
      tokenizer=tokenizer, n_epochs=n_epochs, evaluate_every=1, early_stop=early_stop, 
      grad_accum=grad_accumulate, step_size=step_size, gamma=gamma, 
      max_norm=max_norm, model_type=model_type, beam_size=beam_size,
      max_seq_len=max_seq_len, model_dir=model_dir, exp_id=0, fp16="", device=device)

(Epoch 1) TRAIN LOSS:2.7585 LR:0.00005000: 100%|████████████████████████████████████████████████████████████████████████| 746/746 [01:32<00:00,  8.08it/s]


(Epoch 1) TRAIN LOSS:2.7585 BLEU:26.43 SacreBLEU:30.43 ROUGE1:47.31 ROUGE2:18.77 ROUGEL:42.46 ROUGELsum:42.46 LR:0.00005000


VALID LOSS:2.3251: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.84it/s]


(Epoch 1) VALID LOSS:2.3251 BLEU:15.00 SacreBLEU:15.05 ROUGE1:46.43 ROUGE2:18.12 ROUGEL:42.02 ROUGELsum:42.02


(Epoch 2) TRAIN LOSS:1.9935 LR:0.00004500: 100%|████████████████████████████████████████████████████████████████████████| 746/746 [01:32<00:00,  8.09it/s]


(Epoch 2) TRAIN LOSS:1.9935 BLEU:33.04 SacreBLEU:36.85 ROUGE1:56.15 ROUGE2:27.89 ROUGEL:51.98 ROUGELsum:52.00 LR:0.00004500


VALID LOSS:2.2208: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.01it/s]


(Epoch 2) VALID LOSS:2.2208 BLEU:17.61 SacreBLEU:17.66 ROUGE1:49.02 ROUGE2:20.82 ROUGEL:45.10 ROUGELsum:45.07


(Epoch 3) TRAIN LOSS:1.5562 LR:0.00004050: 100%|████████████████████████████████████████████████████████████████████████| 746/746 [01:31<00:00,  8.13it/s]


(Epoch 3) TRAIN LOSS:1.5562 BLEU:39.30 SacreBLEU:42.80 ROUGE1:62.85 ROUGE2:36.40 ROUGEL:59.59 ROUGELsum:59.58 LR:0.00004050


VALID LOSS:2.2055: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.49it/s]


(Epoch 3) VALID LOSS:2.2055 BLEU:18.93 SacreBLEU:18.97 ROUGE1:50.27 ROUGE2:22.30 ROUGEL:45.96 ROUGELsum:45.98


(Epoch 4) TRAIN LOSS:1.2110 LR:0.00003645: 100%|████████████████████████████████████████████████████████████████████████| 746/746 [01:32<00:00,  8.09it/s]


(Epoch 4) TRAIN LOSS:1.2110 BLEU:46.00 SacreBLEU:49.13 ROUGE1:69.07 ROUGE2:45.06 ROUGEL:66.52 ROUGELsum:66.52 LR:0.00003645


VALID LOSS:2.2276: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.94it/s]


(Epoch 4) VALID LOSS:2.2276 BLEU:20.18 SacreBLEU:20.19 ROUGE1:51.12 ROUGE2:23.53 ROUGEL:47.36 ROUGELsum:47.34


(Epoch 5) TRAIN LOSS:0.9340 LR:0.00003281: 100%|████████████████████████████████████████████████████████████████████████| 746/746 [01:33<00:00,  7.96it/s]


(Epoch 5) TRAIN LOSS:0.9340 BLEU:53.62 SacreBLEU:56.32 ROUGE1:75.09 ROUGE2:54.33 ROUGEL:73.29 ROUGELsum:73.28 LR:0.00003281


VALID LOSS:2.2741: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.61it/s]


(Epoch 5) VALID LOSS:2.2741 BLEU:20.64 SacreBLEU:20.67 ROUGE1:51.24 ROUGE2:24.08 ROUGEL:47.48 ROUGELsum:47.48


(Epoch 6) TRAIN LOSS:0.7234 LR:0.00002952: 100%|████████████████████████████████████████████████████████████████████████| 746/746 [01:32<00:00,  8.09it/s]


(Epoch 6) TRAIN LOSS:0.7234 BLEU:61.54 SacreBLEU:63.80 ROUGE1:80.34 ROUGE2:63.37 ROUGEL:79.12 ROUGELsum:79.11 LR:0.00002952


VALID LOSS:2.3189: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.92it/s]


(Epoch 6) VALID LOSS:2.3189 BLEU:20.47 SacreBLEU:20.48 ROUGE1:50.96 ROUGE2:23.93 ROUGEL:47.30 ROUGELsum:47.34
count stop: 1


(Epoch 7) TRAIN LOSS:0.5727 LR:0.00002657: 100%|████████████████████████████████████████████████████████████████████████| 746/746 [01:32<00:00,  8.10it/s]


(Epoch 7) TRAIN LOSS:0.5727 BLEU:68.59 SacreBLEU:70.43 ROUGE1:84.45 ROUGE2:70.76 ROUGEL:83.58 ROUGELsum:83.58 LR:0.00002657


VALID LOSS:2.3871: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.09it/s]


(Epoch 7) VALID LOSS:2.3871 BLEU:21.31 SacreBLEU:21.34 ROUGE1:51.37 ROUGE2:24.59 ROUGEL:47.70 ROUGELsum:47.67


(Epoch 8) TRAIN LOSS:0.4506 LR:0.00002391: 100%|████████████████████████████████████████████████████████████████████████| 746/746 [01:31<00:00,  8.18it/s]


(Epoch 8) TRAIN LOSS:0.4506 BLEU:75.07 SacreBLEU:76.55 ROUGE1:87.83 ROUGE2:77.28 ROUGEL:87.31 ROUGELsum:87.31 LR:0.00002391


VALID LOSS:2.4403: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.22it/s]


(Epoch 8) VALID LOSS:2.4403 BLEU:21.32 SacreBLEU:21.34 ROUGE1:51.13 ROUGE2:24.34 ROUGEL:47.41 ROUGELsum:47.43
count stop: 1


(Epoch 9) TRAIN LOSS:0.3654 LR:0.00002152: 100%|████████████████████████████████████████████████████████████████████████| 746/746 [01:32<00:00,  8.11it/s]


(Epoch 9) TRAIN LOSS:0.3654 BLEU:80.11 SacreBLEU:81.30 ROUGE1:90.34 ROUGE2:82.13 ROUGEL:90.02 ROUGELsum:90.02 LR:0.00002152


VALID LOSS:2.5097: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.00it/s]


(Epoch 9) VALID LOSS:2.5097 BLEU:21.41 SacreBLEU:21.43 ROUGE1:51.35 ROUGE2:24.61 ROUGEL:47.63 ROUGELsum:47.66


(Epoch 10) TRAIN LOSS:0.2982 LR:0.00001937: 100%|███████████████████████████████████████████████████████████████████████| 746/746 [01:31<00:00,  8.13it/s]


(Epoch 10) TRAIN LOSS:0.2982 BLEU:84.24 SacreBLEU:85.17 ROUGE1:92.29 ROUGE2:85.85 ROUGEL:92.05 ROUGELsum:92.05 LR:0.00001937


VALID LOSS:2.5544: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.95it/s]


(Epoch 10) VALID LOSS:2.5544 BLEU:21.39 SacreBLEU:21.42 ROUGE1:50.89 ROUGE2:24.54 ROUGEL:47.36 ROUGELsum:47.32
count stop: 1


In [13]:
# Load best model
model.load_state_dict(torch.load(model_dir + "/best_model_0.th"))

<All keys matched successfully>

In [14]:
# Evaluate
test_loss, test_metrics, test_hyp, test_label = evaluate(model, data_loader=test_loader, forward_fn=forward_generation, 
                                                         metrics_fn=generation_metrics_fn, model_type=model_type, 
                                                         tokenizer=tokenizer, beam_size=beam_size, 
                                                         max_seq_len=max_seq_len, is_test=True, 
                                                         device='cuda')

TESTING... : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:49<00:00,  1.37it/s]


In [15]:
metrics_scores = []
result_dfs = []

metrics_scores.append(test_metrics)
result_dfs.append(pd.DataFrame({
    'hyp': test_hyp, 
    'label': test_label
}))

result_df = pd.concat(result_dfs)
metric_df = pd.DataFrame.from_records(metrics_scores)

print('== Prediction Result ==')
print(result_df.head())
print()

print('== Model Performance ==')
print(metric_df.describe())

result_df.to_csv(model_dir + "/prediction_result.csv")
metric_df.describe().to_csv(model_dir + "/evaluation_result.csv")

== Prediction Result ==
                                                 hyp  \
0   aku tidak tahu apa yang harus aku katakan, da...   
1              mereka semuanya makan sampai kenyang.   
2   karena kita adalah ciptaan allah yang telah d...   
3   sebab segala sesuatu kudus. tetapi orang-oran...   
4   ketika petrus dan yohanes datang bersama-sama...   

                                               label  
0   sebab memang aku tidak sadar akan sesuatu, te...  
1          dan mereka semuanya makan sampai kenyang.  
2   karena kita ini buatan allah, diciptakan dala...  
3   bagi orang suci semuanya suci; tetapi bagi or...  
4   sesudah dilepaskan pergilah petrus dan yohane...  

== Model Performance ==
            BLEU  SacreBLEU     ROUGE1     ROUGE2     ROUGEL  ROUGELsum
count   1.000000   1.000000   1.000000   1.000000   1.000000   1.000000
mean   18.265366  18.264306  47.181258  22.688521  41.942377  41.903377
std          NaN        NaN        NaN        NaN        NaN        

# Results on other MT Tasks

<img src="indonlg.png"/>

# Results on extremely low-resource MT Tasks

<img src="low_resource_mt.png"/>