# Fine-tuning IndoGPT for ID -> SU (Indonesian to Sundanese) translation

In [1]:
import os, sys
# Add the relative entry '../' to the module search path — presumably so the
# project packages imported below (e.g. `utils`) can be found; confirm against
# the repository layout.
sys.path.append('../')
# Move the working directory one level up; the './dataset/...' and './save/...'
# paths used later in this notebook are relative to the working directory.
# NOTE(review): this path is relative, so running the cell again in the same
# kernel moves up one more level each time.
os.chdir('../')

import torch
import shutil
import random
import numpy as np
import pandas as pd
from torch import optim
from transformers import GPT2LMHeadModel

from indobenchmark import IndoNLGTokenizer
from utils.train_eval import train, evaluate
from utils.metrics import generation_metrics_fn
from utils.forward_fn import forward_generation
from utils.data_utils import MachineTranslationDataset, GenerationDataLoader

In [2]:
###
# common functions
###
def set_seed(seed):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and CUDA) generators with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    
def count_param(module, trainable=False):
    """Return the number of scalar elements across `module.parameters()`.

    When `trainable` is true, only parameters with `requires_grad` set are
    counted; otherwise every parameter contributes.
    """
    return sum(
        param.numel()
        for param in module.parameters()
        if not trainable or param.requires_grad
    )
    
# Set random seed
# A fixed value is applied once here, before the model and data loaders are
# created in the cells that follow.
set_seed(26092020)

# Load Model

In [3]:
# Load the pretrained IndoGPT language model; `gpt_model` and `model` are two
# names bound to the same object.
gpt_model = model = GPT2LMHeadModel.from_pretrained('indobenchmark/indogpt')
# The tokenizer is loaded from the same checkpoint name.
tokenizer = IndoNLGTokenizer.from_pretrained('indobenchmark/indogpt')
# Bare expression: the notebook displays the module structure.
model

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(40005, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): Laye

In [4]:
# Total parameter count; `trainable` keeps its default (False), so parameters
# are counted regardless of their requires_grad flag.
count_param(model)

116566272

# Prepare Dataset

In [5]:
# configs and args

# Optimisation settings. `lr` is given to optim.Adam further down in this cell;
# `gamma`, `step_size`, `max_norm` and `early_stop` are forwarded to train().
# The recorded training log shows the LR multiplied by 0.9 after every epoch,
# consistent with gamma=0.9 / step_size=1 — presumably a step LR schedule
# inside train(); confirm in utils.train_eval.
lr = 1e-4
gamma = 0.9
lower = True  # passed as `lowercase=` to MachineTranslationDataset
step_size = 1
beam_size = 5  # passed to both evaluate() and train()
max_norm = 10
early_stop = 5

# `max_seq_len` is used by the datasets (as `max_token_length`), the loaders,
# evaluate() and train(); `model_type` by the loaders, evaluate() and train().
max_seq_len = 512
grad_accumulate = 1  # passed to train() as `grad_accum`
no_special_token = False  # dataset option
swap_source_target = True  # dataset option
model_type = 'indo-gpt2'
valid_criterion = 'SacreBLEU'  # metric name forwarded to train()

# Hard-coded token ids handed to MachineTranslationDataset.
# NOTE(review): presumably these match entries in the tokenizer vocabulary — verify.
separator_id = 4
speaker_1_id = 5
speaker_2_id = 6

# Batch sizes for the train / valid / test GenerationDataLoader instances.
train_batch_size = 8
valid_batch_size = 8
test_batch_size = 8

# Language tags; used as keys into tokenizer.special_tokens_to_ids in this cell.
source_lang = "[indonesian]"
target_lang = "[sundanese]"

# Look up the token ids of both language tags in a single pass.
src_lid, tgt_lid = (
    tokenizer.special_tokens_to_ids[lang_tag]
    for lang_tag in (source_lang, target_lang)
)

# Adam over all model parameters with the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=lr)

# Record the target-language tag id as the decoder start token on the model
# config (presumably read by the project's generation code — confirm).
model.config.decoder_start_token_id = tgt_lid

# Make sure cuda is deterministic
torch.backends.cudnn.deterministic = True

# create directory
# Output directory for checkpoints and the CSV reports written by later cells.
model_dir = './save/MT_SUNIBS_INZNTV/example_id_su'
# exist_ok=True already makes this a no-op when the directory is present, so a
# separate os.path.exists() check is unnecessary. Without that check, a
# non-directory sitting at this path now raises here instead of being silently
# accepted and failing later on the first write.
os.makedirs(model_dir, exist_ok=True)

In [6]:
train_dataset_path = './dataset/MT_SUNIBS_INZNTV/train_preprocess.json'
valid_dataset_path = './dataset/MT_SUNIBS_INZNTV/valid_preprocess.json'
test_dataset_path = './dataset/MT_SUNIBS_INZNTV/test_preprocess.json'

# Keyword arguments that are identical for the three dataset splits.
dataset_kwargs = dict(
    lowercase=lower,
    no_special_token=no_special_token,
    speaker_1_id=speaker_1_id,
    speaker_2_id=speaker_2_id,
    separator_id=separator_id,
    max_token_length=max_seq_len,
    swap_source_target=swap_source_target,
)
train_dataset = MachineTranslationDataset(train_dataset_path, tokenizer, **dataset_kwargs)
valid_dataset = MachineTranslationDataset(valid_dataset_path, tokenizer, **dataset_kwargs)
test_dataset = MachineTranslationDataset(test_dataset_path, tokenizer, **dataset_kwargs)

# Keyword arguments that are identical for the three loaders; only the dataset,
# batch size and shuffle flag differ per split.
loader_kwargs = dict(
    model_type=model_type,
    tokenizer=tokenizer,
    max_seq_len=max_seq_len,
    src_lid_token_id=src_lid,
    tgt_lid_token_id=tgt_lid,
    num_workers=8,
)
# Only the training split is shuffled.
train_loader = GenerationDataLoader(dataset=train_dataset, batch_size=train_batch_size,
                                    shuffle=True, **loader_kwargs)
valid_loader = GenerationDataLoader(dataset=valid_dataset, batch_size=valid_batch_size,
                                    shuffle=False, **loader_kwargs)
test_loader = GenerationDataLoader(dataset=test_dataset, batch_size=test_batch_size,
                                   shuffle=False, **loader_kwargs)

# Test model to generate sequences

In [7]:
# Encode a short prompt as a batch containing one sequence. The '<s>' marker is
# written into the text and add_special_tokens=False is passed, presumably so
# the tokenizer adds nothing further — confirm.
gpt_input = torch.LongTensor([tokenizer.encode('<s> aku adalah anak', add_special_tokens=False)])
# generate() is called with its defaults only.
# NOTE(review): the recorded output warns that pad_token_id falls back to
# eos_token_id 50256, while the printed model has a 40005-row embedding —
# check the eos/pad ids in the model config.
gpt_out = gpt_model.generate(gpt_input)
# Decode the first (only) generated sequence back to text for display.
tokenizer.decode(gpt_out[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
2021-10-24 22:04:41.322657: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


'<s> aku adalah anak pertama dari tiga bersaudara. </s> aku lahir di kota kecil yang sama dengan ayahku.'

In [8]:
# Second free-form generation check with a different prompt; `gpt_input` and
# `gpt_out` are rebound here.
gpt_input = torch.LongTensor([tokenizer.encode('<s> hai, bagaimana kabar', add_special_tokens=False)])
# Default generate() settings.
gpt_out = gpt_model.generate(gpt_input)
# Decode the first (only) generated sequence back to text for display.
tokenizer.decode(gpt_out[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'<s> hai, bagaimana kabar kalian? semoga sehat selalu ya. kali ini saya akan membahas tentang cara membuat'

# Test model to translate (baseline before fine-tuning)

In [9]:
# Device spec: "cuda" optionally followed by a GPU index, written either as
# "cuda1" or "cuda:1". A value without "cuda" leaves the model where it is.
device = 'cuda1'
# set a specific cuda device
if "cuda" in device:
    # Strip an optional ':' so "cuda", "cuda1" and "cuda:1" are all accepted;
    # the previous int(device[4:]) raised ValueError for "cuda" and "cuda:1".
    gpu_index = device[4:].lstrip(':')
    if gpu_index:
        torch.cuda.set_device(int(gpu_index))
    # From here on the plain "cuda" string is used; the selected GPU is the
    # current CUDA device.
    device = "cuda"
    model = model.cuda()

In [10]:
# Evaluate the pretrained model on the test split before any fine-tuning.
# `device` is the variable set in the device-selection cell (the same one
# train() receives) instead of a hard-coded 'cuda', so evaluation follows
# whatever device the model was actually placed on.
test_loss, test_metrics, test_hyp, test_label = evaluate(model, data_loader=test_loader, forward_fn=forward_generation,
                                                         metrics_fn=generation_metrics_fn, model_type=model_type,
                                                         tokenizer=tokenizer, beam_size=beam_size,
                                                         max_seq_len=max_seq_len, is_test=True,
                                                         device=device)

TESTING... : 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [20:23<00:00,  8.16s/it]


In [11]:
# Collect the pre-fine-tuning test outputs: one metrics record and one frame
# pairing each hypothesis with its reference label.
metrics_scores = [test_metrics]
result_dfs = [pd.DataFrame({
    'hyp': test_hyp,
    'label': test_label
})]

result_df = pd.concat(result_dfs)
metric_df = pd.DataFrame.from_records(metrics_scores)

print('== Prediction Result ==')
print(result_df.head())
print()

print('== Model Performance ==')
print(metric_df.describe())

# The report cell after fine-tuning writes prediction_result.csv and
# evaluation_result.csv into the same directory. Saving the baseline under
# those names meant it was always overwritten, so it gets its own file names;
# the final files are unaffected.
result_df.to_csv(model_dir + "/baseline_prediction_result.csv")
metric_df.describe().to_csv(model_dir + "/baseline_evaluation_result.csv")

== Prediction Result ==
                                                 hyp  \
0  dia yang menghakimi aku, dialah yang menghakim...   
1  <0xE2> <0x80> <0x98> um ar berkata :  <0xE2> <...   
2                               yesus di kayu salib.   
3     tidak ada yang najis dan tidak ada yang najis.   
4  pada waktu itu mereka berkata kepada mereka : ...   

                                               label  
0  da teu terang naon - na on, tur can tangtu sim...  
1    terus kabeh dal alah ar nepi ka sare ub euh na.  
2  sabab urang mah darma dad am elan allah, anu g...  
3  pikeun anu hat ena suci mah, sagala ge suci. s...  
4  sanggeus leupas, petrus jeung yahya nep angan ...  

== Model Performance ==
           BLEU  SacreBLEU    ROUGE1   ROUGE2    ROUGEL  ROUGELsum
count  1.000000    1.00000  1.000000  1.00000  1.000000   1.000000
mean   0.652959    0.66137  2.338202  0.54499  2.260256   2.255067
std         NaN        NaN       NaN      NaN       NaN        NaN
min    0.652959 

# Fine Tuning & Evaluation

In [12]:
# Train

n_epochs = 10

# Fine-tune `model` with the loaders, optimizer and hyper-parameters defined in
# the earlier cells. Validation runs every epoch (evaluate_every=1) and
# `valid_criterion` names the metric passed for model selection.
# The next cell loads model_dir + "/best_model_0.th" — presumably the checkpoint
# train() writes for exp_id=0; confirm in utils.train_eval.
# NOTE(review): fp16 is passed as an empty string; presumably that disables
# mixed precision — confirm against train()'s signature.
train(model, train_loader=train_loader, valid_loader=valid_loader, optimizer=optimizer, 
      forward_fn=forward_generation, metrics_fn=generation_metrics_fn, valid_criterion=valid_criterion, 
      tokenizer=tokenizer, n_epochs=n_epochs, evaluate_every=1, early_stop=early_stop, 
      grad_accum=grad_accumulate, step_size=step_size, gamma=gamma, 
      max_norm=max_norm, model_type=model_type, beam_size=beam_size,
      max_seq_len=max_seq_len, model_dir=model_dir, exp_id=0, fp16="", device=device)

(Epoch 1) TRAIN LOSS:3.5273 LR:0.00010000: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [02:04<00:00,  5.99it/s]


(Epoch 1) TRAIN LOSS:3.5273 BLEU:19.00 SacreBLEU:21.37 ROUGE1:42.41 ROUGE2:13.53 ROUGEL:37.20 ROUGELsum:37.19 LR:0.00010000


VALID LOSS:3.1088: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.14it/s]


(Epoch 1) VALID LOSS:3.1088 BLEU:17.96 SacreBLEU:18.10 ROUGE1:44.66 ROUGE2:16.55 ROUGEL:39.71 ROUGELsum:39.71


(Epoch 2) TRAIN LOSS:2.5424 LR:0.00009000: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [02:04<00:00,  6.01it/s]


(Epoch 2) TRAIN LOSS:2.5424 BLEU:26.00 SacreBLEU:28.14 ROUGE1:52.66 ROUGE2:22.86 ROUGEL:48.17 ROUGELsum:48.17 LR:0.00009000


VALID LOSS:2.9379: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.83it/s]


(Epoch 2) VALID LOSS:2.9379 BLEU:19.93 SacreBLEU:20.07 ROUGE1:47.16 ROUGE2:18.68 ROUGEL:42.42 ROUGELsum:42.43


(Epoch 3) TRAIN LOSS:1.8770 LR:0.00008100: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [02:03<00:00,  6.02it/s]


(Epoch 3) TRAIN LOSS:1.8770 BLEU:33.40 SacreBLEU:35.24 ROUGE1:61.10 ROUGE2:32.61 ROUGEL:57.53 ROUGELsum:57.54 LR:0.00008100


VALID LOSS:2.9793: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.84it/s]


(Epoch 3) VALID LOSS:2.9793 BLEU:21.12 SacreBLEU:21.23 ROUGE1:48.42 ROUGE2:20.52 ROUGEL:43.66 ROUGELsum:43.66


(Epoch 4) TRAIN LOSS:1.3120 LR:0.00007290: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [02:04<00:00,  6.00it/s]


(Epoch 4) TRAIN LOSS:1.3120 BLEU:43.94 SacreBLEU:45.38 ROUGE1:70.35 ROUGE2:45.62 ROUGEL:67.75 ROUGELsum:67.75 LR:0.00007290


VALID LOSS:3.0966: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.96it/s]


(Epoch 4) VALID LOSS:3.0966 BLEU:21.41 SacreBLEU:21.53 ROUGE1:48.41 ROUGE2:20.64 ROUGEL:43.86 ROUGELsum:43.86


(Epoch 5) TRAIN LOSS:0.8570 LR:0.00006561: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [02:03<00:00,  6.03it/s]


(Epoch 5) TRAIN LOSS:0.8570 BLEU:56.92 SacreBLEU:57.93 ROUGE1:78.75 ROUGE2:60.04 ROUGEL:77.19 ROUGELsum:77.18 LR:0.00006561


VALID LOSS:3.2108: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.84it/s]


(Epoch 5) VALID LOSS:3.2108 BLEU:21.70 SacreBLEU:21.83 ROUGE1:48.76 ROUGE2:20.79 ROUGEL:44.10 ROUGELsum:44.06


(Epoch 6) TRAIN LOSS:0.5415 LR:0.00005905: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [02:03<00:00,  6.06it/s]


(Epoch 6) TRAIN LOSS:0.5415 BLEU:70.28 SacreBLEU:70.89 ROUGE1:85.70 ROUGE2:73.27 ROUGEL:84.81 ROUGELsum:84.81 LR:0.00005905


VALID LOSS:3.3348: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.81it/s]


(Epoch 6) VALID LOSS:3.3348 BLEU:21.75 SacreBLEU:21.86 ROUGE1:48.67 ROUGE2:21.08 ROUGEL:44.30 ROUGELsum:44.27


(Epoch 7) TRAIN LOSS:0.3411 LR:0.00005314: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [02:02<00:00,  6.08it/s]


(Epoch 7) TRAIN LOSS:0.3411 BLEU:79.68 SacreBLEU:80.01 ROUGE1:89.97 ROUGE2:82.11 ROUGEL:89.48 ROUGELsum:89.48 LR:0.00005314


VALID LOSS:3.4560: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.34it/s]


(Epoch 7) VALID LOSS:3.4560 BLEU:21.79 SacreBLEU:21.90 ROUGE1:48.57 ROUGE2:20.96 ROUGEL:44.16 ROUGELsum:44.17


(Epoch 8) TRAIN LOSS:0.2231 LR:0.00004783: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [02:03<00:00,  6.06it/s]


(Epoch 8) TRAIN LOSS:0.2231 BLEU:86.19 SacreBLEU:86.36 ROUGE1:92.60 ROUGE2:87.59 ROUGEL:92.31 ROUGELsum:92.31 LR:0.00004783


VALID LOSS:3.5459: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.21it/s]


(Epoch 8) VALID LOSS:3.5459 BLEU:21.89 SacreBLEU:22.02 ROUGE1:48.66 ROUGE2:21.34 ROUGEL:44.53 ROUGELsum:44.51


(Epoch 9) TRAIN LOSS:0.1579 LR:0.00004305: 100%|█████████████████████████████████████████████████████████████████████████████████████| 746/746 [02:02<00:00,  6.08it/s]


(Epoch 9) TRAIN LOSS:0.1579 BLEU:89.63 SacreBLEU:89.71 ROUGE1:93.97 ROUGE2:90.54 ROUGEL:93.77 ROUGELsum:93.77 LR:0.00004305


VALID LOSS:3.6213: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.04it/s]


(Epoch 9) VALID LOSS:3.6213 BLEU:22.01 SacreBLEU:22.14 ROUGE1:48.63 ROUGE2:21.30 ROUGEL:44.24 ROUGELsum:44.19


(Epoch 10) TRAIN LOSS:0.1190 LR:0.00003874: 100%|████████████████████████████████████████████████████████████████████████████████████| 746/746 [02:04<00:00,  6.00it/s]


(Epoch 10) TRAIN LOSS:0.1190 BLEU:91.57 SacreBLEU:91.62 ROUGE1:94.62 ROUGE2:91.91 ROUGEL:94.43 ROUGELsum:94.42 LR:0.00003874


VALID LOSS:3.7146: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.36it/s]


(Epoch 10) VALID LOSS:3.7146 BLEU:22.11 SacreBLEU:22.25 ROUGE1:49.02 ROUGE2:21.73 ROUGEL:44.67 ROUGELsum:44.67


In [13]:
# Load best model
# Restores weights from the checkpoint file under model_dir; the "_0" suffix
# presumably corresponds to exp_id=0 passed to train() — confirm.
# NOTE(review): torch.load is called without map_location, so tensors are
# restored to the devices they were saved from.
model.load_state_dict(torch.load(model_dir + "/best_model_0.th"))

<All keys matched successfully>

In [14]:
# Evaluate
# Score the restored checkpoint on the test split. `device` is the variable set
# in the device-selection cell (the same one train() received) instead of a
# hard-coded 'cuda', so evaluation follows the device the model actually lives on.
test_loss, test_metrics, test_hyp, test_label = evaluate(model, data_loader=test_loader, forward_fn=forward_generation,
                                                         metrics_fn=generation_metrics_fn, model_type=model_type,
                                                         tokenizer=tokenizer, beam_size=beam_size,
                                                         max_seq_len=max_seq_len, is_test=True,
                                                         device=device)

TESTING... : 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [02:26<00:00,  1.02it/s]


In [15]:
# One metrics record and one hypothesis/label frame from the evaluation above,
# kept in lists under the same names as before.
metrics_scores = [test_metrics]
result_dfs = [pd.DataFrame({'hyp': test_hyp, 'label': test_label})]

result_df = pd.concat(result_dfs)
metric_df = pd.DataFrame.from_records(metrics_scores)

# Summary statistics of the metrics table, computed separately for the printed
# report and for the CSV export.
print('== Prediction Result ==')
print(result_df.head())
print()

print('== Model Performance ==')
metric_summary = metric_df.describe()
print(metric_summary)

# Persist predictions and the metric summary next to the checkpoints.
result_df.to_csv(model_dir + "/prediction_result.csv")
metric_df.describe().to_csv(model_dir + "/evaluation_result.csv")

== Prediction Result ==
                                                 hyp  \
0  tur simkuring mah teu sadar kana hal eta, tapi...   
1  tuluy jelema - j el ema teh dal alah ar s ase ...   
2  sabab urang oge beunang allah diciptakan ku ke...   
3  anu suci mah, sagala rupa oge, suci. tapi pike...   
4  ti dinya, petrus jeung yahya nga bu j eng ka m...   

                                               label  
0  da teu terang naon - na on, tur can tangtu sim...  
1    terus kabeh dal alah ar nepi ka sare ub euh na.  
2  sabab urang mah darma dad am elan allah, anu g...  
3  pikeun anu hat ena suci mah, sagala ge suci. s...  
4  sanggeus leupas, petrus jeung yahya nep angan ...  

== Model Performance ==
          BLEU  SacreBLEU     ROUGE1     ROUGE2     ROUGEL  ROUGELsum
count   1.0000   1.000000   1.000000   1.000000   1.000000    1.00000
mean   14.5136  14.502339  32.815334  15.205863  28.659478   28.61727
std        NaN        NaN        NaN        NaN        NaN        NaN
min 