## Install the packages

In [None]:
!pip install -Uqq datasets pythainlp==2.2.4 transformers==4.4.0 tensorflow==2.4.0 tensorflow_text emoji seqeval sentencepiece fuzzywuzzy
!npx degit --force https://github.com/vistec-AI/thai2transformers#dev

In [None]:
%load_ext autoreload
%autoreload 2

import pythainlp, transformers
pythainlp.__version__, transformers.__version__ #fix pythainlp to stabilize word tokenization for metrics

In [None]:
import collections
import logging
import pprint
import re
from tqdm.auto import tqdm

import numpy as np
import torch

#datasets
from datasets import (
    load_dataset, 
    load_metric, 
    concatenate_datasets,
    load_from_disk,
)

#transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    TrainingArguments,
    Trainer,
    default_data_collator,
)

#thai2transformers
import thai2transformers
from thai2transformers.metrics import (
    squad_newmm_metric,
    question_answering_metrics,
)
from thai2transformers.preprocess import (
    prepare_qa_train_features
)
from thai2transformers.tokenizers import (
    ThaiRobertaTokenizer,
    ThaiWordsNewmmTokenizer,
    ThaiWordsSyllableTokenizer,
    FakeSefrCutTokenizer,
    SEFR_SPLIT_TOKEN
)

from tqdm import tqdm

In [None]:
# Checkpoint names this notebook is set up for.
model_names = [
    'wangchanberta-base-att-spm-uncased',
    'xlm-roberta-base',
    'bert-base-multilingual-cased',
    'wangchanberta-base-wiki-newmm',
    'wangchanberta-base-wiki-ssg',
    'wangchanberta-base-wiki-sefr',
    'wangchanberta-base-wiki-spm',
]

# Tokenizer class used to load each checkpoint.
tokenizers = {
    'wangchanberta-base-att-spm-uncased': AutoTokenizer,
    'xlm-roberta-base': AutoTokenizer,
    'bert-base-multilingual-cased': AutoTokenizer,
    'wangchanberta-base-wiki-newmm': ThaiWordsNewmmTokenizer,
    'wangchanberta-base-wiki-ssg': ThaiWordsSyllableTokenizer,
    # The form below offers this spelling for the syllable-level model; without
    # an entry here, choosing it fails with a KeyError on the lookup below.
    'wangchanberta-base-wiki-syllable': ThaiWordsSyllableTokenizer,
    'wangchanberta-base-wiki-sefr': FakeSefrCutTokenizer,
    'wangchanberta-base-wiki-spm': ThaiRobertaTokenizer,
}
# Checkpoints loaded by their bare name; every other name gets the `airesearch/` prefix.
public_models = ['xlm-roberta-base', 'bert-base-multilingual-cased']
#@title Choose Pretrained Model
model_name = "wangchanberta-base-att-spm-uncased" #@param ["wangchanberta-base-att-spm-uncased", "xlm-roberta-base", "bert-base-multilingual-cased", "wangchanberta-base-wiki-newmm", "wangchanberta-base-wiki-syllable", "wangchanberta-base-wiki-sefr", "wangchanberta-base-wiki-spm"]

#create tokenizer
tokenizer = tokenizers[model_name].from_pretrained(
                model_name if model_name in public_models else f'airesearch/{model_name}',
                revision='main',
                model_max_length=416,)

## Prepare functions for calculating metrics

In [None]:
!pip install rouge

In [None]:
from rouge import Rouge 
rouge = Rouge()
def cal_rouge_score(hyps, refs, get_average_f1=True):
  '''
  Score each hypothesis against its reference with ROUGE-1, ROUGE-2 and ROUGE-L.

  Args:
    hyps: list of hypothesis strings.
    refs: list of reference strings, paired with `hyps` by position
          (pairs beyond the shorter list are ignored).
    get_average_f1: if True, return only the mean F1 of each metric.

  Returns:
    If get_average_f1 is True: a tuple (rouge-1 F1, rouge-2 F1, rouge-L F1),
    each averaged over all pairs; (0.0, 0.0, 0.0) when there are no pairs.
    Otherwise: three dicts (r1, r2, rl), each with the keys 'precision',
    'recall' and 'f1' holding one value per pair.
  '''
  metric_names = ('rouge-1', 'rouge-2', 'rouge-l')
  # Key used by the `rouge` package -> key used in the returned dicts.
  stat_keys = (('p', 'precision'), ('r', 'recall'), ('f', 'f1'))
  results = {name: {'precision': [], 'recall': [], 'f1': []} for name in metric_names}

  for hyp, ref in zip(hyps, refs):
    if len(hyp) == 0 or len(ref) == 0:
      # An empty side is scored as zero here instead of being sent to the scorer.
      score = {name: {'p': 0, 'r': 0, 'f': 0} for name in metric_names}
    else:
      score = rouge.get_scores(hyp, ref)[0]
    for name in metric_names:
      for source_key, target_key in stat_keys:
        # Each statistic is read from its own key (precision from 'p', recall from 'r').
        results[name][target_key].append(score[name][source_key])

  r1, r2, rl = (results[name] for name in metric_names)
  if get_average_f1 == True:
    def _mean(values):
      # Guard against an empty input list, which would otherwise divide by zero.
      return sum(values) / len(values) if values else 0.0
    return _mean(r1['f1']), _mean(r2['f1']), _mean(rl['f1'])
  return r1, r2, rl

In [None]:
# Smoke test for cal_rouge_score on one hand-written pair.
cands = ['test test test test test test bad']
refs = ['test test']

# With the default get_average_f1=True the three results are the averaged
# ROUGE-1, ROUGE-2 and ROUGE-L F1 values.
r1, r2, rl = cal_rouge_score(cands, refs)
print(r1)
print(r2)
print(rl)

## Utility functions for computing the labels used in this notebook

In [None]:
def tokenize_with_space(texts, tokenizer):
  """
  Tokenize each text and return its tokens joined by single spaces.

  Args:
    texts: list of raw strings.
    tokenizer: callable tokenizer returning 'input_ids'; it is called with
               max_length=416 and truncation=True.

  Returns:
    One string per input text: the non-special tokens separated by spaces,
    with one leading "▁" character removed and surrounding whitespace stripped.
    A text that yields no tokens becomes "".
  """
  batch_ids = tokenizer(texts, max_length=416, truncation=True)['input_ids']
  spaced = []
  for ids in batch_ids:
    tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
    joined = " ".join(tokens)
    # startswith is False for the empty string, so the empty case falls through to "".
    if joined.startswith("▁"):
      joined = joined[1:]
    spaced.append(joined.strip())
  return spaced


def selection_start_end(paragraphs_raw, summaries_raw, tokenizer, length_sum_max = 10, metric='rouge-l'):
  """
  For each paragraph, pick the contiguous token span that best matches its summary.

  Both inputs are tokenized with `tokenize_with_space`; every span of
  1 .. length_sum_max-1 tokens is scored against the summary with
  `cal_rouge_score`, and the span with the highest ROUGE-L F1 wins.

  Args:
    paragraphs_raw: list of raw paragraph strings.
    summaries_raw: list of raw summary strings, paired with the paragraphs by position.
    tokenizer: tokenizer passed through to `tokenize_with_space`.
    length_sum_max: exclusive upper bound on the span length, in tokens.
    metric: currently unused; the ROUGE-L F1 is always the selection score.

  Returns:
    (start_position, end_position, texts_all): three lists with one entry per
    paragraph. start/end are token indices into the space-tokenized paragraph
    (end exclusive); the text is the selected tokens concatenated without
    separators. If no span scores above 0, the entry is (0, number of tokens, "").
  """
  paragraphs = tokenize_with_space(paragraphs_raw, tokenizer)
  summaries = tokenize_with_space(summaries_raw, tokenizer)
  start_position = []
  end_position = []
  texts_all = []
  for paragraph_raw, summary in zip(paragraphs, summaries):
    paragraph = paragraph_raw.split(" ")
    len_paragraph = len(paragraph)
    max_score = 0
    s = 0
    e = len_paragraph
    text = ""
    for length in range(1, length_sum_max):
      for start_pos in range(len_paragraph-length+1):
        t_summary = " ".join(paragraph[start_pos:start_pos+length])
        try:
          # NOTE(review): the summary is passed in the `hyps` slot and the
          # candidate span in the `refs` slot — confirm this order is intended.
          _, _, score = cal_rouge_score([summary], [t_summary])
        except Exception:
          # Best effort: skip spans the scorer cannot handle, but let
          # KeyboardInterrupt / SystemExit propagate so the loop can be stopped.
          continue
        if max_score < score:
          max_score = score
          s = start_pos
          e = start_pos + length
          text = "".join(paragraph[s:e])
    start_position.append(s)
    end_position.append(e)
    texts_all.append(text)
  return start_position, end_position, texts_all


In [None]:
import collections as coll
# stopwords = pkgutil.get_data(__package__, 'smart_common_words.txt')
# stopwords = stopwords.decode('ascii').split('\n')
# stopwords = {key.strip(): 1 for key in stopwords}

def _get_ngrams_count(n, text):
    """Calcualtes n-grams.
    Args:
      n: which n-grams to calculate
      text: An array of tokens
    Returns:
      A set of n-grams
    """
    ngram_dic = coll.defaultdict(int)
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_dic[tuple(text[i:i + n])] += 1
    return ngram_dic

def _get_ngrams(n, text):
    """Calcualtes n-grams.
    Args:
      n: which n-grams to calculate
      text: An array of tokens
    Returns:
      A set of n-grams
    """
    ngram_set = set()
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_set.add(tuple(text[i:i + n]))
    return ngram_set

def _get_word_ngrams_list(n, text):
    """Calcualtes n-grams.
    Args:
      n: which n-grams to calculate
      text: An array of tokens
    Returns:
      A set of n-grams
    """
    text = sum(text, [])
    ngram_set = []
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_set.append(tuple(text[i:i + n]))
    return ngram_set

def _get_word_ngrams(n, sentences, do_count=False):
    """Calculates word n-grams over the concatenation of several sentences.
    Args:
      n: which n-grams to calculate (must be positive)
      sentences: A non-empty list of token lists
      do_count: if True return n-gram counts, otherwise the set of n-grams
    Returns:
      The result of _get_ngrams_count when do_count is True, else of _get_ngrams
    """
    assert len(sentences) > 0
    assert n > 0

    tokens = sum(sentences, [])
    builder = _get_ngrams_count if do_count else _get_ngrams
    return builder(n, tokens)
  
def cal_rouge(evaluated_ngrams, reference_ngrams):
    """Computes overlap-based precision, recall and F1 between two n-gram sets.
    Args:
      evaluated_ngrams: set of candidate n-grams
      reference_ngrams: set of reference n-grams
    Returns:
      A dict with keys "f", "p" and "r"; precision or recall is 0.0 when the
      corresponding set is empty
    """
    n_reference = len(reference_ngrams)
    n_evaluated = len(evaluated_ngrams)
    n_overlap = len(evaluated_ngrams.intersection(reference_ngrams))

    precision = n_overlap / n_evaluated if n_evaluated != 0 else 0.0
    recall = n_overlap / n_reference if n_reference != 0 else 0.0

    # The 1e-8 term keeps the division defined when both values are zero.
    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
    return {"f": f1_score, "p": precision, "r": recall}

def selection_start_end_r1_r2(doc, abstract, tokenizer, summary_size = 50):
  """
  Pick the contiguous token span of `doc` that best matches `abstract`.

  Every span of 1 .. summary_size-1 tokens is scored by the sum of its
  unigram F1 and (for spans longer than one token) its bigram F1 against
  the abstract, as computed by `cal_rouge`; the first best-scoring span wins.

  Args:
    doc: raw document string.
    abstract: raw reference summary string.
    tokenizer: tokenizer passed through to `tokenize_with_space`.
    summary_size: exclusive upper bound on the span length, in tokens.

  Returns:
    (start, end, text): token indices into the space-tokenized document
    (end exclusive) and the selected tokens concatenated without separators.
    (0, 0, "") when no span scores above 0.
  """
  tokenized_doc = tokenize_with_space([doc], tokenizer)[0].split(" ")
  tokenized_abstract = tokenize_with_space([abstract], tokenizer)[0].split(" ")

  # Positional n-gram lists for the document: entry i is the n-gram starting at token i.
  evaluated_1grams = _get_word_ngrams_list(1, [tokenized_doc])
  evaluated_2grams = _get_word_ngrams_list(2, [tokenized_doc])
  # Sets of n-grams for the reference summary.
  reference_1grams = _get_word_ngrams(1, [tokenized_abstract])
  reference_2grams = _get_word_ngrams(2, [tokenized_abstract])

  start = 0
  end = 0
  text = ""
  max_rouge = 0.0
  for s in range(1, summary_size):
    for i in range(len(tokenized_doc)-s+1):
      candidates_1 = set(evaluated_1grams[i:i+s])
      rouge_f = cal_rouge(candidates_1, reference_1grams)['f']
      if s > 1:
        # A span of s tokens starting at i contains the s-1 bigrams that start
        # at positions i .. i+s-2. They must come from the bigram list; slicing
        # the unigram list here never overlaps the reference bigrams.
        candidates_2 = set(evaluated_2grams[i:i+s-1])
        rouge_f += cal_rouge(candidates_2, reference_2grams)['f']
      if rouge_f > max_rouge:
        max_rouge = rouge_f
        start = i
        end = i+s
        text = "".join(tokenized_doc[i:i+s])

  return start, end, text

## Preprocess data

In [None]:
!gdown --id 1-8IU8qyry-yPXwQ7AXz0GHIgn19QKGZP
!gdown --id 1-J0eqf4ig7cP8bMPRgSFUejshnBFTZoq
!gdown --id 1-IIJFl4AGNr7rRax4YSQTTm7j12YJ0ya

In [None]:
import pandas as pd
# Load the three CSV files fetched by the gdown cell above.
df = pd.read_csv('thaisum.csv')
val_df = pd.read_csv('validation_set.csv')
test_df = pd.read_csv('test_set.csv')
# Stack the rows in the order thaisum, validation, test; later cells recover
# the splits by slicing this frame on row counts.
# NOTE(review): presumably 'thaisum.csv' is the training split — confirm.
df = pd.concat([df, val_df, test_df], axis=0)

In [None]:
df = df.reset_index(drop=True)
df['body'][358868+11000]

In [None]:
def gold_summary(df, num_train, num_val, num_test):
  """
  Return the reference summaries of the test split as a list.

  The test split is the `num_test` rows that follow the first
  `num_train + num_val` rows of `df` (positional slicing).
  """
  test_start = num_train + num_val
  test_rows = df.iloc[test_start:test_start + num_test, :]
  return test_rows['summary'].tolist()

In [None]:
def get_tokenized_df(df):
  """
  Build model inputs and span labels for every row of `df`.

  Args:
    df: DataFrame with string columns 'body' and 'summary'.

  Returns:
    DataFrame with one row per input row and the columns 'attention_mask',
    'input_ids', 'start_positions' and 'end_positions'. The inputs come from
    tokenizing 'body' padded/truncated to 416 tokens; the positions come from
    `selection_start_end_r1_r2` on the lower-cased body and summary.
  """
  df = df.reset_index(drop=True)
  columns = ['attention_mask', 'input_ids', 'start_positions', 'end_positions']
  # Collect plain records and build the frame once; appending row by row
  # copies the whole frame each time and is not available in pandas 2.x.
  records = []
  for i in tqdm(range(len(df))):
    body = df['body'][i]
    start, end, _ = selection_start_end_r1_r2(body.lower(), df['summary'][i].lower(), tokenizer)
    # NOTE(review): labels are computed on the lower-cased body while the model
    # inputs use the original casing — confirm the tokenizer is case-insensitive.
    encoded = tokenizer(body, max_length=416, truncation=True, padding='max_length')
    records.append({'attention_mask': encoded['attention_mask'],
                    'input_ids': encoded['input_ids'],
                    'start_positions': start,
                    'end_positions': end})
  return pd.DataFrame(records, columns=columns)

def get_tokenized_dict(df, num_train, num_val, num_test):
  """
  Tokenize the train, validation and test splits of `df`.

  The splits are consecutive positional row ranges of sizes `num_train`,
  `num_val` and `num_test`. Returns a dict with the keys 'train',
  'validation' and 'test', each holding the output of `get_tokenized_df`.
  """
  val_end = num_train + num_val
  test_end = val_end + num_test
  train_rows = df.iloc[:num_train, :]
  val_rows = df.iloc[num_train:val_end, :]
  test_rows = df.iloc[val_end:test_end, :]
  splits = {}
  splits['train'] = get_tokenized_df(train_rows)
  splits['validation'] = get_tokenized_df(val_rows)
  splits['test'] = get_tokenized_df(test_rows)
  return splits

def get_tokenized_dict_test_val(df, num_train, num_val, num_test):
  """
  Tokenize only the validation and test splits of `df`.

  The first `num_train` rows are skipped; the next `num_val` rows form the
  validation split and the following `num_test` rows the test split.
  Returns a dict with the keys 'validation' and 'test'.
  """
  val_end = num_train + num_val
  test_end = val_end + num_test
  val_rows = df.iloc[num_train:val_end, :]
  test_rows = df.iloc[val_end:test_end, :]
  splits = {}
  splits['validation'] = get_tokenized_df(val_rows)
  splits['test'] = get_tokenized_df(test_rows)
  return splits

def get_tokenized_dict_test(df, num_train, num_val, num_test):
  """
  Tokenize only the test split of `df`.

  The test split is the `num_test` rows that follow the first
  `num_train + num_val` rows. Returns a dict with the single key 'test'.
  """
  test_start = num_train + num_val
  test_rows = df.iloc[test_start:test_start + num_test, :]
  return {'test': get_tokenized_df(test_rows)}

In [None]:
tokenize_with_space([df['body'][369868]], tokenizer)

Tokenizing usually takes a long time. To tokenize only part of the data, comment out the first call below and uncomment one of the alternatives.

In [None]:
# %%time
tokenized_datasets = get_tokenized_dict(df, 358868, 11000, 11000)
# tokenized_datasets = get_tokenized_dict_test_val(df, 358868, 11000, 11000)
# tokenized_datasets = get_tokenized_dict_test(df, 358868, 11000, 11000)

In [None]:
gold_summaries = gold_summary(df, 358868, 11000, 11000)

You can choose to save the data after preprocessing and load it.

In [None]:
# Optional: uncomment to persist the freshly tokenized splits as JSON-lines files.
# NOTE(review): the save paths below ('train.json', '.../validation_true_set.json',
# '.../test_true_lower.json') differ from the paths loaded at the end of this cell —
# align them before relying on the save/load round trip.
# tokenized_datasets['train'].to_json('train.json', orient='records', lines=True)
# tokenized_datasets['validation'].to_json('/content/drive/MyDrive/validation_true_set.json', orient='records', lines=True)
# tokenized_datasets['test'].to_json('/content/drive/MyDrive/test_true_lower.json', orient='records', lines=True)
# Replaces the in-memory `tokenized_datasets` with the splits stored under /content/drive/MyDrive.
tokenized_datasets = load_dataset('json', data_files={'train': '/content/drive/MyDrive/train.json', 'validation': '/content/drive/MyDrive/validation_true.json', 'test': '/content/drive/MyDrive/test_true.json'})

In [None]:
tokenized_datasets

In [None]:
# Inspect one validation example: decode its full input and its labelled span.
i = 8
example = tokenized_datasets['validation'][i]
combined_text = tokenizer.decode(example['input_ids'])
# NOTE(review): post_process_index decodes input_ids[start+1:end+1]; this preview
# slices [start:end], so it may be shifted by one token — confirm which is intended.
answer_with_token_idx = tokenizer.decode(example['input_ids'][example['start_positions']:example['end_positions']])

# Show the size of the validation split next to the decoded span and input.
len(tokenized_datasets['validation']), answer_with_token_idx, combined_text

## Fine-tuning model

In [None]:
model = AutoModelForQuestionAnswering.from_pretrained(
            f'airesearch/{model_name}' if model_name not in public_models else f'{model_name}',
            revision='main',)

In [None]:
# Hyperparameters shared by the training arguments below.
batch_size = 16
learning_rate = 4e-5

# Training configuration; the first positional argument is the output directory.
args = TrainingArguments(
    f"finetune_thaiSum",
    evaluation_strategy = "epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    # Evaluation uses twice the training batch size.
    per_device_eval_batch_size=batch_size*2,
    num_train_epochs=6,
    warmup_ratio=0.15,
    weight_decay=0.01,
    fp16=True,
    save_total_limit=3,
    load_best_model_at_end=True,
)

# Inputs are already padded to a fixed length during preprocessing, so the
# default collator is used as-is.
data_collator = default_data_collator

In [None]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
trainer.train()

In [None]:
trainer.save_model("/content/drive/MyDrive/finetune_thaiSum4")

## Post-processing and metrics (BERTScore; ROUGE was already set up above)

In [None]:
def post_process_index(data, raw_predictions, tokenizer, n_best_size = 20, max_answer_length=50):
  """
  Turn start/end logits into one predicted text span per example.

  Args:
    data: iterable of examples, each with an 'input_ids' sequence.
    raw_predictions: pair (all_start_logits, all_end_logits), one row per example.
    tokenizer: object whose `decode(ids, skip_special_tokens=True)` returns text.
    n_best_size: how many top start and top end indices to combine.
    max_answer_length: spans with end - start + 1 above this are rejected.

  Returns:
    List of decoded strings, one per example. For the best-scoring valid
    (start, end) pair the tokens input_ids[start+1 : end+1] are decoded;
    "" when no pair is valid.
  """
  all_start_logits, all_end_logits = raw_predictions
  predictions = []
  for start_logits, end_logits, example in zip(all_start_logits, all_end_logits, data):
    # Indices of the n_best_size largest logits, highest first.
    start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
    end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
    best_score = None
    best_span = None
    for start_index in start_indexes:
      for end_index in end_indexes:
        # Don't consider answers with a length that is either < 0 or > max_answer_length.
        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
          continue
        span_score = start_logits[start_index] + end_logits[end_index]
        # Strict comparison keeps the first pair among equal scores.
        if best_score is None or span_score > best_score:
          best_score = span_score
          best_span = (start_index, end_index)
    if best_span is None:
      # No valid pair: emit an empty prediction instead of failing.
      predictions.append("")
    else:
      # Decode only the winning span rather than every candidate.
      start_index, end_index = best_span
      predictions.append(
          tokenizer.decode(example['input_ids'][start_index+1:end_index+1], skip_special_tokens=True)
      )
  return predictions

### BERTScore

In [None]:
!pip install bert_score==0.3.7

In [None]:
from bert_score import score
import numpy as np
import gc

In [None]:
def cal_bert_score(cands, refs, get_average_f1=True):
  '''
  Compute BERTScore for candidate/reference pairs.

  arguments: cands, refs [list of string], get_average_f1=True
  return arrays of precision, recall, f1 (one value per pair)
  if get_average_f1 == True return only the mean of the f1 values
  '''
  precision, recall, f1 = (
      tensor.numpy() for tensor in score(cands, refs, lang="others", verbose=False)
  )
  if get_average_f1 == True:
    return f1.mean()
  return precision, recall, f1

def cal_batch_bert_score(cands, refs, get_average_f1=True, batch_size=8):
  '''
  Compute the mean BERTScore F1 over all pairs, scoring them in batches.

  arguments:
    cands, refs: lists of strings, paired by position
    get_average_f1: accepted for signature compatibility; the mean F1 is always returned
    batch_size: number of pairs scored per call to cal_bert_score
  return the mean F1 over every scored pair (0.0 when there are no pairs)

  The per-batch means are printed as before, but the returned value is
  weighted by batch size so a shorter final batch does not skew the mean.
  '''
  batch_means = []
  f1_total = 0.0
  n_scored = 0
  for i in tqdm(range(0, len(cands), batch_size)):
    cand_batch = cands[i:i+batch_size]
    ref_batch = refs[i:i+batch_size]
    _, _, f1 = cal_bert_score(cand_batch, ref_batch, get_average_f1=False)
    batch_means.append(f1.mean())
    f1_total += float(f1.sum())
    n_scored += len(f1)
    # Free memory between batches.
    gc.collect()
  print(batch_means)
  if n_scored == 0:
    return 0.0
  return f1_total / n_scored

In [None]:
%%time
refs = ['เมื่อวันที่ 6 ม.ค.60 ที่ทำเนียบรัฐบาล นายวิษณุ เครืองาม รองนายกรัฐมนตรี กล่าวถึงกรณี ที่ นายสุรชัย เลี้ยงบุญเลิศชัย รองประธานสภานิติบัญญัติแห่งชาติ (สนช.) ออกมาระบุว่า การเลือกตั้งจะถูกเลื่อนออกไปถึงปี 2561 ว่า ขอให้ไปสอบถามกับ สนช. แต่เชื่อว่าคงไม่กล้าพูดอีก เพราะทำให้คนเข้าใจผิด ซึ่งที่ สนช.พูดเนื่องจากผูกกับกฎหมายของกรรมการร่างรัฐธรรมนูญ(กรธ.) ตนจึงไม่ขอวิพากษ์วิจารณ์ แต่รัฐบาลยืนยันว่ายังเดินตามโรดแม็ป ซึ่งโรดแม็ปมองได้สองแบบ คือ มีลำดับขั้นตอนและการกำหนดช่วงเวลา โดยเริ่มต้นจากการประกาศใช้รัฐธรรมนูญ แต่ขณะนี้รัฐธรรมนูญยังไม่ประกาศใช้ จึงยังเริ่มนับหนึ่งไม่ถูก จากนั้นเข้าสู่ขั้นตอนการร่างกฎหมายประกอบร่างรัฐธรรมนูญหรือกฎหมายลูก ภายใน 240 วัน ก่อนจะส่งกลับให้ สนช.พิจารณา ภายใน 2 เดือน\xa0,นายวิษณุ กล่าวต่อว่า หากมีการแก้ไขก็จะมีการพิจารณาร่วมกับ กรธ.อีก 1 เดือน ก่อนนำขึ้นทูลเกล้าฯ ทรงลงพระปรมาภิไธย ภายใน 90 วัน และจะเข้าสู่การเลือกตั้งภายในระยะเวลา 5 เดือน ซึ่งทั้งหมดนี้คือโรดแม็ปที่ยังเป็นแบบเดิมอยู่ ส่วนเดิมที่กำหนดวันเลือกตั้งไว้ภายในปี 60 นั้น เพราะมาจากสมมติฐานของขั้นตอนเดิมทั้งหมด แต่เมื่อมีเหตุสวรรคตทุกอย่างจึงต้องเลื่อนออกไป ส่วนการพิจารณากฎหมายลูกทั้งหมด 4 ฉบับ ขณะนี้กรธ.พิจารณาแล้วเสร็จ 2 ฉบับ คือ พ.ร.ป.พรรคการเมือง และพ.ร.ป. คณะกรรมการการเลือกตั้ง แต่ พ.ร.ป.การเลือกตั้งควรจะพิจารณาได้เร็วกลับล่าช้า ดังนั้น กรธ.จะต้องออกชี้แจงถึงเหตุผลว่าทำไมพิจารณากฎหมายดังกล่าวล่าช้ากว่ากำหนด ส่งผลให้เกิดข้อสงสัยจนถึงทุกวันนี้ ส่วนกรณีที่ สนช. ระบุว่า มีกฎหมายเข้าสู่การพิจารณาของ สนช.เป็นจำนวนมาก ทำให้ส่งผลกระทบต่อโรดแม็ปนั้น รัฐบาลเคยบอกไว้แล้วว่าในช่วงนี้ของโรดแม็ปกฎหมายจะเยอะกว่าที่ผ่านมา ดังนั้น สนช.จะต้องบริหารจัดการกันเอง เพราะได้มีการเพิ่มสมาชิก สนช.ให้แล้ว.']
cands = ['เมื่อวันที่ 6 ม.ค.60 ที่ทำเนียบรัฐบาล นายวิษณุ เครืองาม รองนายกรัฐมนตรี กล่าวถึงกรณี ที่ นายสุรชัย เลี้ยงบุญเลิศชัย รองประธานสภานิติบัญญัติแห่งชาติ (สนช.)']
f1_average = cal_bert_score(cands, refs)

In [None]:
print(f1_average)

### Evaluate

In [None]:
def evaluate_rouge(cands, refs, tokenizer):
  """
  Average ROUGE-1 / ROUGE-2 / ROUGE-L F1 of candidates against references.

  Both lists are first converted to space-separated tokens with
  `tokenize_with_space` (which truncates each text to 416 tokens).

  Args:
    cands: list of candidate (predicted) summaries.
    refs: list of reference summaries, paired with `cands` by position.
    tokenizer: tokenizer passed through to `tokenize_with_space`.

  Returns:
    Tuple (r1, r2, rl) of mean F1 values as returned by `cal_rouge_score`.
  """
  cands_tokenized = tokenize_with_space(cands, tokenizer)
  refs_tokenized = tokenize_with_space(refs, tokenizer)
  # cal_rouge_score takes (hyps, refs): candidates are the hypotheses.
  return cal_rouge_score(cands_tokenized, refs_tokenized)

In [None]:
raw_predictions = trainer.predict(tokenized_datasets['test'])

In [None]:
predictions = post_process_index(tokenized_datasets['test'], raw_predictions[0], tokenizer)

In [None]:
predictions[:3]

In [None]:
display(predictions[:3], gold_summaries[:3])

In [None]:
r1, r2, rl = evaluate_rouge(predictions, gold_summaries, tokenizer)
print(r1, r2, rl)

In [None]:
%%time
# Mean BERTScore F1 of the predictions against the gold summaries, scored 128
# pairs at a time. cal_batch_bert_score takes no tokenizer; its third positional
# parameter is `get_average_f1`, so only the batch size is passed here.
BERTScore = cal_batch_bert_score(predictions, gold_summaries, batch_size=128)

In [None]:
print(BERTScore)