# Datasets and Model training

### Libraly

In [None]:
!pip install transformers sentencepiece

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
!pip install datasets
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 KB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 KB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting aiohttp
  Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash
  Downloading xxhash-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.2/212

In [None]:
import numpy as np
import pandas as pd
from datasets import load_dataset, DatasetDict, Dataset

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-564M")

Downloading (…)okenizer_config.json:   0%|          | 0.00/433 [00:00<?, ?B/s]

Downloading (…)tencepiece.bpe.model:   0%|          | 0.00/4.92M [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/9.03M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/276 [00:00<?, ?B/s]

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Datasets

In [None]:
# append all datasets to a list
data = []
data.append(pd.read_csv('/content/drive/MyDrive/Klong/Klong_haripunchai.csv')) 
data.append(pd.read_csv('/content/drive/MyDrive/Klong/Klong_lokanit.csv'))
data.append(pd.read_csv('/content/drive/MyDrive/Klong/Klong_supan.csv'))
data.append(pd.read_csv('/content/drive/MyDrive/Klong/Klong_wadruak.csv'))
data.append(pd.read_csv('/content/drive/MyDrive/Klong/klong_jaofahapai.csv'))
klong_dataset_list = []
for klong in range(5):
  for rowidx in range(data[klong].shape[0]):
    klong_dataset_list.append('\n'.join(data[klong].iloc[rowidx])) # join with \n

In [None]:
# split train valid function 
def split_data(data:list, test_split_ratio:float) -> tuple:
  split_calculation = 100 - (round(len(data) * test_split_ratio)) # need to be int
  valid = data[:split_calculation]
  train = data[split_calculation:]
  return train, valid

In [None]:
# split train valid 0.2
klong_train, klong_valid = split_data(klong_dataset_list, 0.2)
# to datasets
klong_train = Dataset.from_dict({"content": klong_train})
klong_valid = Dataset.from_dict({"content": klong_valid})

In [None]:
# shuffle data
raw_datasets = DatasetDict(
    {
        "train": klong_train.shuffle(),  # .shuffle().select(range(50000)),
        "valid": klong_valid.shuffle(),  # .shuffle().select(range(500))
    }
)

#### Tokenizing

In [None]:
# tokenize data
def tokenize(element, context_length=128): # context_length will cut of
  outputs = tokenizer(
      element["content"],
      truncation=True,
      max_length=context_length,
      return_overflowing_tokens=True,
      return_length=True,
  )
  print(outputs.keys())
  input_batch = []
  for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
      input_batch.append(input_ids) # if enabled it will ignore token lenght shorter than context_length
  return {"input_ids": input_batch}

In [None]:
tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
)

Map:   0%|          | 0/165 [00:00<?, ? examples/s]

dict_keys(['input_ids', 'attention_mask', 'length', 'overflow_to_sample_mapping'])


Map:   0%|          | 0/1161 [00:00<?, ? examples/s]

dict_keys(['input_ids', 'attention_mask', 'length', 'overflow_to_sample_mapping'])
dict_keys(['input_ids', 'attention_mask', 'length', 'overflow_to_sample_mapping'])


### Modeling

In [None]:
from transformers import DataCollatorForLanguageModeling

In [None]:
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [None]:
try:
  out = data_collator([tokenized_datasets["train"][i] for i in range(len(tokenized_datasets["train"]))])
except:
  out = data_collator(tokenized_datasets["train"])
for key in out:
    print(f"{key} shape: {out[key].shape}")

You're using a XGLMTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


input_ids shape: torch.Size([165, 66])
attention_mask shape: torch.Size([165, 66])
labels shape: torch.Size([165, 66])


### Training


In [None]:
model = AutoModelForCausalLM.from_pretrained("facebook/xglm-564M")

In [None]:
# log in hugging face
from huggingface_hub import notebook_login
notebook_login() 

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="/content/drive/MyDrive/klong2",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    evaluation_strategy="steps",
    eval_steps=500,
    logging_steps=500,
    gradient_accumulation_steps=8,
    num_train_epochs=100,
    weight_decay=0.1,
    warmup_steps=1_000,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=5_000,
    fp16=True,
    push_to_hub=True,
    report_to="wandb"
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
)

In [None]:
# train model
trainer.train()

# Check and generate


### Khavee and check eak tou 

In [None]:
kv = KhaveeVerifier()

class KhaveeVerifier:
    def check_sara(self,word):
        sara = []
        countoa = 0
        # In case การันย์
        if '์' in word[-1]:
            word = word[:-2]
        # In case สระเดี่ยว
        for i in word:
            if i == 'ะ' or i == 'ั':
                sara.append('อะ')
            elif i == 'ิ':
                sara.append('อิ')
            elif i == 'ุ':
                sara.append('อุ')
            elif i == 'ึ':
                sara.append('อึ')
            elif i == 'ี':
                sara.append('อี')
            elif i == 'ู':
                sara.append('อู')
            elif i == 'ื':
                sara.append('อือ')
            elif i == 'เ':
                sara.append('เอ')
            elif i == 'แ':
                sara.append('แอ')
            elif i == 'า':
                sara.append('อา') 
            elif i == 'โ':
                sara.append('โอ')
            elif i == 'ำ':
                sara.append('อำ')
            elif i == 'อ':
                countoa += 1
                sara.append('ออ')
            elif i == 'ั' and 'ว' in word:
                sara.append('อัว')
            elif i == 'ไ' or i == 'ใ':
                sara.append('ไอ') 
            elif 'รร' in word:
                if self.check_marttra(word) == 'กม':
                    sara.append('อำ')
                else:
                    sara.append('อะ')
        # Incase ออ
        if countoa == 1 and 'อ' in word[-1]:
            sara.remove('ออ')
        # In case เอ เอ 
        countA = 0
        for i in sara:
            if i == 'เอ':
                countA = countA + 1
            if countA > 1:
                sara.remove('เอ')
                sara.remove('เอ')
                sara.append('แ')
        # In case สระประสม
        if 'เอ' in sara and 'อะ' in sara:
            sara.remove('เอ')
            sara.remove('อะ')
            sara.append('เอะ')
        elif 'แอ' in sara and 'อะ' in sara:
            sara.remove('แอ')
            sara.remove('อะ')
            sara.append('แอะ')
        if 'เอะ' in sara and 'ออ' in sara:
            sara.remove('เอะ')
            sara.remove('ออ')
            sara.append('เออะ')
        elif 'เอ' in sara and 'อิ' in sara:
            sara.remove('เอ')
            sara.remove('อิ')
            sara.append('เออ')        
        elif 'เอะ' in sara and 'อา' in sara:
            sara.remove('เอะ')
            sara.remove('ออ')
            sara.append('เอาะ')
        elif 'เอ' in sara and 'ออ' in sara and 'อ' in word[-1]:
            sara.remove('เอ')
            sara.remove('ออ')
            sara.append('เออ')
        elif 'โอ' in sara and 'อะ' in sara: 
            sara.remove('โอ')
            sara.remove('อะ')
            sara.append('โอะ')
        elif 'เอ' in sara and 'อี' in sara: 
            sara.remove('เอ')
            sara.remove('อี')
            sara.append('เอีย')
        elif 'เอ' in sara and 'อือ' in sara: 
            sara.remove('เอ')
            sara.remove('อือ')
            sara.append('อัว')   
        elif 'เอ' in sara and 'อา' in sara: 
            sara.remove('เอ')
            sara.remove('อา')
            sara.append('เอา') 
        if 'อือ' in sara and 'เออ' in sara: 
            sara.remove('เออ')
            sara.remove('อือ')
            sara.append('เอือ')  
        elif 'ออ' in sara and len(sara) > 1:
            sara.remove('ออ')        
        elif 'ว' in word and len(sara) == 0:
            sara.append('อัว')
        if 'ั' in word and self.check_marttra(word) == 'กา':
            sara = []
            sara.append('ไอ')
        # In case อ
        if word == 'เออะ':
            sara = []
            sara.append('เออะ')
        elif word == 'เออ':
            sara = []
            sara.append('เออ')
        elif word == 'เอ':
            sara = []
            sara.append('เอ')
        elif word == 'เอะ':
            sara = []
            sara.append('เอะ')
        elif word == 'เอา':
            sara = []
            sara.append('เอา')
        if 'ฤา' in word or 'ฦา' in word:
            sara = []
            sara.append('อือ') 
        elif 'ฤ' in word or 'ฦ' in word:
            sara = []
            sara.append('อึ') 
        # In case กน
        if sara == [] and len(word) == 2:
            if word[-1] != 'ร':
                sara.append('โอะ')
            else:
                sara.append('ออ') 
        elif sara == [] and len(word) == 3:
            sara.append('ออ') 
        if sara == []:
            return 'Cant find Sara in this word'
        else:
            return sara[0]


    def check_marttra(self,word):
        if word[-1] == 'ร' and word[-2] in ['ต','ท'] :
            word = word[:-1]
            # print(word)
        if '์' in word[-1]:
            if 'ิ' in word[-2] or 'ุ' in word[-2]:
                word = word[:-3]
            else:
                word = word[:-2]
        if 'ำ' in word or ('ํ' in word and 'า' in word) or 'ไ' in word or 'ใ' in word:
            return 'กา'
        elif word[-1] in ['า','ะ','ิ','ี','ุ','ู','อ'] or ('ี' in word and 'ย' in word[-1]) or ('ื' in word and 'อ' in word[-1]):
            return 'กา'
        elif word[-1] in ['ง']:
            return 'กง'
        elif word[-1] in ['ม']:
            return 'กม'
        elif word[-1] in ['ย']:
            if 'ั' in word:
                return 'กา'
            else:
                return 'เกย'
        elif word[-1] in ['ว']:
            return 'เกอว'
        elif word[-1] in ['ก','ข','ค','ฆ']:
            return 'กก'
        elif word[-1] in ['จ','ช','ซ','ฎ','ฏ','ฐ','ฑ','ฒ','ด','ต','ถ','ท','ธ','ศ','ษ','ส'] :
            return 'กด'
        elif word[-1] in ['ญ',', ณ' ,'น' ,'ร' ,'ล' ,'ฬ']:
            return 'กน'
        elif word[-1] in ['บ', 'ป', 'พ', 'ฟ', 'ภ']:
            return 'กบ'
        else:
           return 'Cant find Marttra in this word'

    def check_sumpus(self,word1,word2):
        marttra1 = self.check_marttra(word1)
        marttra2 = self.check_marttra(word2)
        sara1 = self.check_sara(word1)
        sara2 = self.check_sara(word2)
        if sara1 == 'อะ' and marttra1 == 'เกย':
            sara1 = 'ไอ'
            marttra1 = 'กา'
        elif sara2 == 'อะ' and marttra2 == 'เกย':
            sara2 = 'ไอ'
            marttra2 = 'กา'
        if sara1 == 'อำ' and marttra1 == 'กม':
            sara1 = 'อำ'
            marttra1 = 'กา'
        elif sara2 == 'อำ' and marttra2 == 'กม':
            sara2 = 'อำ'
            marttra2 = 'กา'
        # print(marttra1,marttra2)
        # print(sara1,sara2)
        if marttra1 == marttra2 and sara1 == sara2:
            return True
        else:
            return False

    def check_klon(self,text,k_type=8):
        if k_type == 8:
            try:
                error = []
                list_sumpus_sent1 = []
                list_sumpus_sent2h = []
                list_sumpus_sent2l = []
                list_sumpus_sent3 = []
                list_sumpus_sent4 = []
                for i, sent in enumerate(text.split()):
                    sub_sent = subword_tokenize(sent, engine='dict')
                    # print(i)
                    if len(sub_sent) > 10:
                        error.append('In the sentence'+str(i+2)+'there are more than 10 words.'+str(sub_sent))
                    if (i+1) % 4 == 1:
                        list_sumpus_sent1.append(sub_sent[-1])
                    elif (i+1) % 4 == 2:
                        list_sumpus_sent2h.append([sub_sent[1],sub_sent[2],sub_sent[3],sub_sent[4]])
                        list_sumpus_sent2l.append(sub_sent[-1])
                    elif (i+1) % 4 == 3:
                        list_sumpus_sent3.append(sub_sent[-1])
                    elif (i+1) % 4 == 0:
                        list_sumpus_sent4.append(sub_sent[-1])
                if len(list_sumpus_sent1) != len(list_sumpus_sent2h) or len(list_sumpus_sent2h) != len(list_sumpus_sent2l) or len(list_sumpus_sent2l) != len(list_sumpus_sent3) or len(list_sumpus_sent3) != len(list_sumpus_sent4)  or len(list_sumpus_sent4) != len(list_sumpus_sent1):
                    return 'The poem does not complete 4 sentences.'
                else:
                    for i in range(len(list_sumpus_sent1)):
                        countwrong = 0
                        for j in list_sumpus_sent2h[i]:
                            if self.check_sumpus(list_sumpus_sent1[i],j) == False:
                                    countwrong +=1
                        if  countwrong > 3:
                            error.append('Cant find rhyme between paragraphs '+str((list_sumpus_sent1[i],list_sumpus_sent2h[i]))+'in paragraph '+str(i+1))
                        if self.check_sumpus(list_sumpus_sent2l[i],list_sumpus_sent3[i]) == False:
                            # print(sumpus_sent2l,sumpus_sent3)
                            error.append('Cant find rhyme between paragraphs '+str((list_sumpus_sent2l[i],list_sumpus_sent3[i]))+'in paragraph '+str(i+1))
                        if i > 0:
                            if self.check_sumpus(list_sumpus_sent2l[i],list_sumpus_sent4[i-1]) == False:
                                error.append('Cant find rhyme between paragraphs '+str((list_sumpus_sent2l[i],list_sumpus_sent4[i-1]))+'in paragraph '+str(i+1))
                    if error == []:
                        return 'The poem is correct according to the principle.'
                    else:
                        return error
            except:
                return 'Something went wrong Make sure you enter it in correct form of klon4.'
        elif k_type == 4:
            try:
                error = []
                list_sumpus_sent1 = []
                list_sumpus_sent2h = []
                list_sumpus_sent2l = []
                list_sumpus_sent3 = []
                list_sumpus_sent4 = []
                for i, sent in enumerate(text.split()):
                    sub_sent = subword_tokenize(sent, engine='dict')
                    if len(sub_sent) > 5:
                        error.append('In the sentence'+str(i+2)+'there are more than 4 words.'+str(sub_sent))
                    if (i+1) % 4 == 1:
                        list_sumpus_sent1.append(sub_sent[-1])
                    elif (i+1) % 4 == 2:
                        # print([sub_sent[1],sub_sent[2]])
                        list_sumpus_sent2h.append([sub_sent[1],sub_sent[2]])
                        list_sumpus_sent2l.append(sub_sent[-1])
                    elif (i+1) % 4 == 3:
                        list_sumpus_sent3.append(sub_sent[-1])
                    elif (i+1) % 4 == 0:
                        list_sumpus_sent4.append(sub_sent[-1])
                if len(list_sumpus_sent1) != len(list_sumpus_sent2h) or len(list_sumpus_sent2h) != len(list_sumpus_sent2l) or len(list_sumpus_sent2l) != len(list_sumpus_sent3) or len(list_sumpus_sent3) != len(list_sumpus_sent4)  or len(list_sumpus_sent4) != len(list_sumpus_sent1):
                    return 'The poem does not complete 4 sentences.'
                else:
                    for i in range(len(list_sumpus_sent1)):
                        countwrong = 0
                        for j in list_sumpus_sent2h[i]:
                            # print(list_sumpus_sent1[i],j)
                            if self.check_sumpus(list_sumpus_sent1[i],j) == False:
                                    countwrong +=1
                        if  countwrong > 1:
                            error.append('Cant find rhyme between paragraphs '+str((list_sumpus_sent1[i],list_sumpus_sent2h[i]))+'in paragraph '+str(i+1))
                        if self.check_sumpus(list_sumpus_sent2l[i],list_sumpus_sent3[i]) == False:
                            # print(sumpus_sent2l,sumpus_sent3)
                            error.append('Cant find rhyme between paragraphs '+str((list_sumpus_sent2l[i],list_sumpus_sent3[i]))+'in paragraph '+str(i+1))
                        if i > 0:
                            if self.check_sumpus(list_sumpus_sent2l[i],list_sumpus_sent4[i-1]) == False:
                                error.append('Cant find rhyme between paragraphs '+str((list_sumpus_sent2l[i],list_sumpus_sent4[i-1]))+'in paragraph '+str(i+1))
                    if error == []:
                        return 'The poem is correct according to the principle.'
                    else:
                        return error
            except:
                return 'Something went wrong Make sure you enter it in correct form.'
            
        else:
            return 'Something went wrong Make sure you enter it in correct form.'






def check_aek_too(text: str or list[str]) -> str or bool or list[str or bool]:
    """
    Check if the word is aek or too or False(not both)
    :param str or list[str] text: Thai word or list of Thai words
    :return: the check if the word is aek or too or False(not both) or list of the check if input is list
    :rtype: str or bool or list[str or bool]
    :Example:
    ::
        from pythainlp.khavee import KhaveeVerifier
        kv = KhaveeVerifier()
        # การเช็คคำเอกโท
        print(kv.check_aek_too('เอง'), kv.check_aek_too('เอ่ง'), kv.check_aek_too('เอ้ง'))
        ## -> False, aek, too
        print(kv.check_aek_too(['เอง', 'เอ่ง', 'เอ้ง'])) # ใช้ List ได้เหมือนกัน
        ## -> [False, 'aek', 'too']
    """
    if isinstance(text, list):
        return [self.check_aek_too(t) for t in text]

    if not isinstance(text, str):
        raise TypeError('text must be str or iterable list[str]')

    word_characters = [*text]
    if '่' in word_characters and not '้' in word_characters:
        return 'aek'
    elif '้' in word_characters and not '่' in word_characters:
        return 'too'
    else:
        return False

In [None]:
kv = KhaveeVerifier()

In [3]:
# install pythainlp and ssg(subword tokenizer)
!pip install pythainlp
!pip install ssg

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pythainlp
  Downloading pythainlp-3.1.1-py3-none-any.whl (9.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.6/9.6 MB[0m [31m44.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pythainlp
Successfully installed pythainlp-3.1.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ssg
  Downloading ssg-0.0.8-py3-none-any.whl (473 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m473.8/473.8 KB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
Collecting fire>=0.1.3
  Downloading fire-0.5.0.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.3/88.3 KB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting python-crfsuite>=0.9.6
  Downloading python_crfsuite-0.9.9-cp39-cp39-manylinux_2_17_x8

In [4]:
import pythainlp as pythai
from pythainlp.tokenize import word_tokenize
from pythainlp.tokenize import subword_tokenize
from pythainlp.util import sound_syllable
from pythainlp.util import isthai
from pythainlp.transliterate import pronunciate
from tqdm import tqdm
import numpy as np
import pandas as pd
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Transformers
!pip install transformers sentencepiece
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("/content/drive/MyDrive/klong", local_files_only=True)
model = AutoModelForCausalLM.from_pretrained("/content/drive/MyDrive/klong", local_files_only=True)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### Word and Subword Tokenizing

In [None]:
# split text from \n to list and drop soi word ->  splitted wak list (no soi)
def split_klong(klong_text):
  splitted_klong = []
  klong_list = klong_text.split('\n')
  klong_list = [klong for klong in klong_list if klong.strip()]
  for i in range(len(klong_list)):
    if i == 1 or i == 3 or i == 5: 
      klong = klong_list[i]
      if klong[0] == ' ': 
        klong = klong[1:]
      klong = klong.split(' ')
      splitted_klong.append(klong[0])
    else:
      splitted_klong.append(klong_list[i].replace(' ', ''))
  return splitted_klong

In [None]:
# subword tokenize wak with ssg and dict
def subword_token(wak, engine='ssg'):
  subword_tokenized = subword_tokenize(wak, engine='ssg')
  if len(subword_tokenized) != 5 and len(subword_tokenized) != 2:
      subword_tokenized = subword_tokenize(wak, engine='dict')
  return subword_tokenized

### Check Functions

#### Number of syllables check

In [None]:
# check number of syllables -> [True, True, True, True, True, True, True, True] (len=8)
def subword_num(splitted_klong):
  checked = []
  two = [1,3,5]
  five = [0,2,4,6]
  for num in range(len(splitted_klong)):
    if num in two:
      checked.append(len(subword_token(splitted_klong[num])) == 2)
    elif num in five: 
      checked.append(len(subword_token(splitted_klong[num])) == 5)
    elif num == 7:
      checked.append(len(subword_token(splitted_klong[num])) == 4)
  return checked

#### eak tou check


In [None]:
# check what word tone is 
def find_tone(word):
  char_list = [*word]
  if "่" in char_list or sound_syllable(word) == 'dead':
    return "eak or dead"
  elif "้" in char_list:
    return "tou"
  else:
    return False

In [None]:
# check eaktou -> list[True, True, True, True, True, True, True, True] (len=8)
def check_eaktou(splitted_klong):
  checked = []
  for num in range(len(splitted_klong)):
    tokenzied_wak = subword_token(splitted_klong[num])
    if num == 0:
      checked.append(find_tone(tokenzied_wak[3]) == "eak or dead" and find_tone(tokenzied_wak[4]) == 'tou')
    elif num == 1:
      checked.append(True)
    elif num == 2:
      checked.append(find_tone(tokenzied_wak[1]) == "eak or dead")
    elif num == 3:
      checked.append(find_tone(tokenzied_wak[0]) == 'eak or dead' and find_tone(tokenzied_wak[1]) == 'tou')
    elif num == 4:
      checked.append(find_tone(tokenzied_wak[2]) == 'eak or dead')
    elif num == 5:
      checked.append(find_tone(tokenzied_wak[1]) == 'eak or dead')
    elif num == 6:
      checked.append(find_tone(tokenzied_wak[1]) == "eak or dead" and find_tone(tokenzied_wak[4]) == 'tou')
    elif num == 7:
      checked.append(find_tone(tokenzied_wak[0]) == "eak or dead" and find_tone(tokenzied_wak[1]) == 'tou')
  return checked

#### sampas check

In [None]:
# last sound of wak from pronunciate tokenized last word of each wak
# ex [เสียงลือเสียงเล่าอ้าง] -> [อ้าง]
def sound_words(splitted_klong):
  sound_list = []
  for wak in splitted_klong:
    list_char = [*wak]
    if " " in list_char:
      wak = wak.split(" ")
      wak = wak[0]
    wak = word_tokenize(wak, engine="newmm")
    pronounce_word = pronunciate(wak[-1], engine="w2p")
    sound_list.append(pronounce_word.replace('ฺ', '').split('-')[-1])
  return sound_list

In [None]:
# check sampas -> [True, True, True] 
# [0] = sampas wak 2-3, [1] = sampas wak 2-4, [2] sampas wak 4-7
def check_sampas(sound_list):
  checked = []
  if len(sound_list) > 2:
    checked.append(kv.check_sumpus(sound_list[1],sound_list[2]))
    if len(sound_list) > 4:
      checked.append(kv.check_sumpus(sound_list[1],sound_list[4]))
      if len(sound_list) > 6:
        checked.append(kv.check_sumpus(sound_list[3],sound_list[6]))
  else:
    checked.append(True)
  return checked

#### Main Check

In [None]:
def main_check(klong_text):
  splitted_klong = split_klong(klong_text)
  checked_subword_num = subword_num(splitted_klong)
  if False in checked_subword_num:
    false_index = checked_subword_num.index(False)
    return 'syllable format error', false_index+1
  else:
    checked_eaktou = check_eaktou(splitted_klong)
    if False in checked_eaktou:
      false_index = checked_eaktou.index(False)
      return 'eaktou format error', false_index+1
    else:
      sound_list = sound_words(splitted_klong)
      checked_sampas = check_sampas(sound_list)
      if False in checked_sampas:
        wak_sampas = ['2 and 3', '2 and 5', '4 and 7']
        return 'sampas format error', wak_sampas[checked_sampas.index(False)]
      else:
        return True

### Generate Klong

In [7]:
input_text = 'เสียงลือเสี่ยงเล่าอ้าง\nอันใด พี่เอย\nเสี่ยงย่อมยอยศใคร\nทั่วหล้า\nสองเขือพี่หลับใหล\nลืมตื่น ฤาพี่\nสองพี่คิดเองอ้า\nอย่าได้ถามเผือ'
def gen_prob_next_token(text:str, model, tokenizer):
  input_ids = tokenizer(input_text, return_tensors="pt")
  #look at tensor shape
  input_ids,input_ids['input_ids'].shape 

  #get logit of the next token
  outputs = model(input_ids['input_ids'])
  logits = outputs.logits
  logits.shape #the size is equal to input token because it's predicting the next one

  #convert logit to prob; use the logits of the last input token
  import torch.nn.functional as F
  probs = F.softmax(logits[:, -1, :], dim=-1).squeeze() 
  probs, probs.argmax()

  #match prob with vocab
  import pandas as pd
  df = pd.DataFrame(tokenizer.vocab.items(), columns=['token', 'token_id']).sort_values('token_id').reset_index(drop=True)
  
  df['prob'] = probs.detach().numpy()

  possible_token = df.sort_values('prob',ascending=False).reset_index()
  thai_only = [x if isthai(x) else None for x in possible_token['token']] # thai only
  possible_token['token'] = thai_only
  possible_token = possible_token.dropna()
  return possible_token

last_sampas = word_tokenize(input_text.split(' ')[0])[-1]
last_sampas = pronunciate(last_sampas).split('-')[-1]
prob = gen_prob_next_token(input_text, model, tokenizer) # prob คือคำที่เป็นไปได้ทั้งหมด

input_ids {'input_ids': tensor([[     2, 173710, 141086, 151704,  64302,  94899,  96039,  20542, 101239,
          13996,   2300,      6, 151704,  55652,   1731,   2300,  23606,   4900,
          19831,      6,  74336,  61044,  18753,      6,  22769,  29837,  33754,
          32220,   3987,  36160, 192006,   2940,      6, 116017,  87659,      6,
          63907,   6964,  32220,      6,  22769,  32220,  18936,  17246,   6325,
          18753,  93998,   1621,  35581,  43965,  33754]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1]])}


In [None]:
def gen_rules(probs):
  passed = []
  for prob in probs:
    if len(prob) != 1 and len(subword_token(prob)) == 1 and '-' not in pronunciate(prob):
        passed.append(prob)
  return passed

In [None]:
def get_sampassed(data:list, sampaswith):
  passed = []
  counter_exception = 0
  for possible_word in tqdm(data):
    possible_sampas = pronunciate(possible_word).split('-')[-1] # reduce word dimension
    sampaswith = pronunciate(sampaswith).split('-')[-1] # reduce word dimension
    try:
      if kv.check_sumpus(possible_sampas, sampaswith):
        passed.append(possible_word)
    except:
      counter_exception += 1
      continue
  return passed

In [None]:
# aek or too
def get_aek_too(data:list, ktype='aek'):
  passed = []
  for possible_word in tqdm(data):
      if check_aek_too(possible_word) == ktype:
        passed.append(possible_word)
  return passed

In [None]:
def tone_gen(klong_text, gened_word, word_mark=None, sampas=False):
  splitted_klong = split_klong(klong_text)
  if word_mark == None and sampas == False:
     probs = generator(klong_text)
     for prob in probs:
       if prob not in gened_word:
         gened_word.append(prob)
         return prob, gened_word
  elif word_mark == 'aek' and sampas == False:
    probs = generator(klong_text)
    aek = get_aek_too(probs)
    for prob in aek:
      if prob not in gened_word:
        gened_word.append(prob)
        return prob, gened_word
  elif word_mark == 'too' and sampas == False:
    probs = generator(klong_text)
    too = get_aek_too(probs, 'too')
    for prob in too:
      if prob not in gened_word:
        gened_word.append(prob)
        return prob, gened_word
  elif sampas == True and word_mark == None:
    probs = generator(klong_text)
    passed = get_sampassed(probs, sound_words(splitted_klong)[1])
    for prob in passed:
       if prob not in gened_word:
         gened_word.append(prob)
         return prob, gened_word
  elif sampas == True and word_mark == 'too':
    probs = generator(klong_text)
    passed = get_sampassed(probs, sound_words(splitted_klong)[3])
    for prob in passed:
       if prob not in gened_word and check_aek_too(prob) == 'too':
         gened_word.append(prob)
         return prob, gened_word

In [None]:
def gen_klong(klong_text, gened_word):
  splitted_klong = split_klong(klong_text)
  # วรรค 2, 4, 6
  if len(splitted_klong) in [1, 3, 5]:
    word_gen = 2
    if len(splitted_klong) == 1:
      # ฉันทลักษณ์ (none, none(sampas))
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      klong_text = klong_text + '\n'
    elif len(splitted_klong) == 3:
      # ฉันทลักษณ์ (aek, too(sampas))
      prob, gened_word  = tone_gen(klong_text, gened_word, word_mark='aek')
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word, 'too')
      klong_text = klong_text + prob
      klong_text = klong_text + '\n'
    elif len(splitted_klong) == 5:
      # ฉันทลักษณ์ (none, aek)
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word, word_mark='aek')
      klong_text = klong_text + prob
      klong_text = klong_text + '\n'

  # วรรค 3, 5, 7
  elif len(splitted_klong) in [2, 4, 6]:
    word_gen = 5
    if len(splitted_klong) == 2:
      # ฉันทลักษณ์ (none, aek, none, none, none(sampas))
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word, word_mark='aek')
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      sampas_word = sound_words(splitted_klong)[1]
      prob, gened_word  = tone_gen(klong_text, gened_word, word_mark=None, sampas=True)
      klong_text = klong_text + prob
      klong_text = klong_text + '\n'
    elif len(splitted_klong) == 4:
      # ฉันทลักษณ์ (none, none, aek, none, none(sampas))
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word, word_mark='aek')
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      sampas_word = sound_words(splitted_klong)[1]
      prob, gened_word  = tone_gen(klong_text, gened_word, word_mark=None, sampas=True)
      klong_text = klong_text + prob
      klong_text = klong_text + '\n'
    elif len(splitted_klong) == 6:
      # ฉันทลักษณ์ (none, aek, none, none, too(sampas))
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word, word_mark='aek')
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      prob, gened_word  = tone_gen(klong_text, gened_word)
      klong_text = klong_text + prob
      sampas_word = sound_words(splitted_klong)[1]
      prob, gened_word  = tone_gen(klong_text, gened_word, word_mark='too', sampas=True)
      klong_text = klong_text + prob
      klong_text = klong_text + '\n'
  # วรรค 8
  elif len(splitted_klong) == 7:
    # ฉันทลักษณ์ (eak, too, none, none)
    word_gen = 4
    prob, gened_word  = tone_gen(klong_text, gened_word, word_mark='aek')
    klong_text = klong_text + prob
    prob, gened_word  = tone_gen(klong_text, gened_word, 'too')
    klong_text = klong_text + prob
    prob, gened_word  = tone_gen(klong_text, gened_word)
    klong_text = klong_text + prob
    prob, gened_word  = tone_gen(klong_text, gened_word)
    klong_text = klong_text + prob
    klong_text = klong_text + '\n'
  return klong_text, gened_word

In [None]:
def generator(klong):
  prob = gen_prob_next_token(klong, model, tokenizer)
  new_prob = gen_rules(prob['token'].tolist())
  return new_prob

### Main Function


In [None]:
def main(klong_text):
  gened_klong = []
  splitted = split_klong(klong_text)
  if main_check(klong_text) == True:
    wak_num = len(splitted)
    for i in range(8-wak_num):
      klong_text, gened_klong = gen_klong(klong_text, gened_klong)
    return klong_text
  else:
    return main_check(klong_text)

# MAIN

In [None]:
input_text = 'เสียงลือเสียงเล่าอ้าง\n'
main(input_text)

100%|██████████| 1521/1521 [00:00<00:00, 364586.60it/s]
100%|██████████| 1521/1521 [00:17<00:00, 84.62it/s] 
100%|██████████| 1521/1521 [00:00<00:00, 263497.43it/s]
100%|██████████| 1521/1521 [00:00<00:00, 193583.26it/s]
100%|██████████| 1521/1521 [00:00<00:00, 244286.29it/s]
100%|██████████| 1521/1521 [00:18<00:00, 80.65it/s]
100%|██████████| 1521/1521 [00:00<00:00, 381414.35it/s]
100%|██████████| 1521/1521 [00:00<00:00, 403002.93it/s]
100%|██████████| 1521/1521 [00:16<00:00, 92.90it/s] 
100%|██████████| 1521/1521 [00:00<00:00, 386919.97it/s]
100%|██████████| 1521/1521 [00:00<00:00, 273811.60it/s]


'บรรเลงเพลงขับร้อง\n\nกลอก\nชายเล่ฟังนพดอก\nแผ่นต้น\nเรียนึกเยี่ยมกลางรอก\nหวสั่ง\nหญิงหนึ่งโกนองค้น\nพ่อแก้หยอร\n'