In [None]:
!pip install pythainlp==3.0.5
!curl 'https://raw.githubusercontent.com/kokkoks/thai-joke-sentence-generator/main/siaw_caption.txt' > siaw_caption.txt

In [None]:
!pip install datasets transformers==4.11.2
!pip install sentencepiece

In [None]:
from datasets import load_dataset
from pythainlp import word_tokenize
# datasets = load_dataset("text", data_files={"train": path_to_train.txt, "validation": path_to_validation.txt}

# เตรียมข้อมูล (Prepare data)

In [None]:
def txt_to_lst(file_path, encoding="utf-8"):
    """Read a text file and return its contents split on newline characters.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            Thai text decodes the same way regardless of the platform locale.

    Returns:
        List of lines without their newline terminators. Blank lines are kept,
        and a trailing newline yields a final empty string.
    """
    # `with` guarantees the file handle is closed once the contents are read.
    with open(file_path, "r", encoding=encoding) as f:
        return f.read().split('\n')

In [None]:
def export2txt(name, item, tokenize=None, encoding="utf-8"):
    """Write each string in ``item`` to file ``name`` as one space-joined token line.

    Each string is split into tokens (pythainlp ``word_tokenize`` by default),
    the tokens are joined with ``' '``, and the line is terminated with a newline.
    An existing file at ``name`` is overwritten.

    Args:
        name: Output file path.
        item: Iterable of strings to write, one output line per string.
        tokenize: Callable mapping a string to a list of tokens. Defaults to
            pythainlp's ``word_tokenize``.
        encoding: Output encoding; UTF-8 so Thai text is written correctly
            regardless of the platform locale.
    """
    if tokenize is None:
        tokenize = word_tokenize
    with open(name, 'w', encoding=encoding) as f:
        for line in item:
            f.write(' '.join(tokenize(line)))
            f.write('\n')

In [None]:
# Split the captions into train / validation: the last VALID_FRACTION of the
# (unshuffled) lines is held out for validation.
VALID_FRACTION = 0.2
# Relative path: the download cell saves the file in the working directory.
# Blank lines carry no caption text (the file's trailing newline alone yields an
# empty final entry), so they are dropped before splitting.
data = [line for line in txt_to_lst('siaw_caption.txt') if line.strip()]
count = len(data)
result_percent = count * VALID_FRACTION
result_split = int(round(count - result_percent))
train = data[:result_split]
valid = data[result_split:]
assert len(train) + len(valid) == count

In [None]:
# Write each split as word-segmented text, one caption per line.
for out_path, lines in (('train.txt', train), ('valid.txt', valid)):
    export2txt(out_path, lines)

In [None]:
# Load the word-segmented files written by export2txt (relative to the working
# directory, where they were written); each line becomes one example.
# The validation split must point at valid.txt -- evaluating on train.txt would
# only measure fit to the training data.
datasets = load_dataset("text", data_files={"train": "train.txt",
                                            "validation": "valid.txt"})

In [None]:
# Peek at one raw training example (index 10) to check the space-joined token format.
datasets["train"][10]

In [None]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Render ``num_examples`` distinct, randomly chosen rows of ``dataset`` as an HTML table.

    Columns whose feature type is ``ClassLabel`` are shown by their label names
    instead of integer ids.

    Args:
        dataset: A Hugging Face ``Dataset`` (indexable by a list of row indices,
            with a ``features`` mapping).
        num_examples: Number of distinct rows to display.

    Raises:
        AssertionError: if ``num_examples`` exceeds the dataset length.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws distinct indices directly, instead of re-rolling until an
    # unseen index appears (which also scanned the picks list on every draw).
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [None]:
show_random_elements(datasets["train"], num_examples=10)

In [None]:
# Hub checkpoint for the tokenizer; its last path component also names the
# training output directory. Swap in a commented checkpoint to rerun the GPT-2 /
# mT5 experiments (the model-loading cells have matching commented alternatives).
model_checkpoint = "airesearch/wangchanberta-base-att-spm-uncased"
# model_checkpoint = "gpt2"
# model_checkpoint = "google/mt5-base"

In [None]:
from transformers import AutoTokenizer
    
# Tokenizer matching the chosen checkpoint; prefer the Rust-backed "fast" implementation.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
    use_fast=True,
)

In [None]:
def tokenize_function(examples):
    """Tokenize a batch's ``text`` column with the notebook-level ``tokenizer``.

    Intended for ``Dataset.map(batched=True)``; returns the tokenizer's output
    columns (e.g. ``input_ids``, ``attention_mask``).
    """
    batch_texts = examples["text"]
    encoded = tokenizer(batch_texts)
    return encoded

In [None]:
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=["text"],  # keep only the tokenizer's output columns
)

In [None]:
# Inspect one tokenized example: the tokenizer output columns that replaced "text".
tokenized_datasets["train"][1]

In [None]:
# Tokens per training example: group_texts concatenates tokenized lines and cuts
# them into blocks of this length.
block_size = 128

In [None]:
def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized lines and re-split it into fixed-size blocks.

    Used with ``Dataset.map(batched=True)`` to build causal-LM training examples.

    Args:
        examples: Mapping from column name (e.g. ``input_ids``, ``attention_mask``)
            to a list of per-line token lists for the current batch.
        chunk_size: Length of each output block. When omitted, the notebook-level
            ``block_size`` is used.

    Returns:
        Dict with the same columns, each a list of ``chunk_size``-long lists, plus a
        ``labels`` column copied from ``input_ids``. Tokens left over after the last
        full block are dropped.

    Raises:
        ValueError: if the block size is not positive.
    """
    from itertools import chain

    size = block_size if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError(f"chunk_size must be positive, got {size}")
    # chain.from_iterable flattens in linear time; sum(lists, []) rebuilds the
    # accumulator list at every step and is quadratic in the batch's token count.
    concatenated_examples = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    # All columns come from the same tokenizer call, so the first one's length applies.
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the tail that does not fill a whole block.
    total_length = (total_length // size) * size
    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated_examples.items()
    }
    # Causal-LM labels are the inputs themselves (Hugging Face causal-LM heads
    # shift them internally).
    result["labels"] = result["input_ids"].copy()
    return result

In [None]:
# Re-chunk every split into fixed-length blocks, 1000 tokenized lines per map batch.
map_kwargs = dict(batched=True, batch_size=1000, num_proc=4)
lm_datasets = tokenized_datasets.map(group_texts, **map_kwargs)

In [None]:
# Decode one grouped training block back to text to check how lines were concatenated.
tokenizer.decode(lm_datasets["train"][1]["input_ids"])

# เตรียม Fine-tune Model (Prepare the model for fine-tuning)

In [None]:
from transformers import AutoModelForCausalLM,AutoModelForSeq2SeqLM,AutoModel

# Load the same checkpoint the tokenizer came from, so the two always match.
# is_decoder=True: transformers' RoBERTa-family *ForCausalLM classes apply a causal
# attention mask only when config.is_decoder is set (they log a warning otherwise).
# Without it, each position can attend to the very tokens it is trained to predict.
# GPT-2 is always causal and simply ignores the flag. For mT5, switch to the
# Seq2Seq line below instead.
model = AutoModelForCausalLM.from_pretrained(model_checkpoint, is_decoder=True)
# model = AutoModelForCausalLM.from_pretrained("gpt2")
# model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-base")

In [None]:
from transformers import Trainer, TrainingArguments

In [None]:
# Output directory is named after the checkpoint, e.g. "wangchanberta-base-att-spm-uncased-finetuned".
model_name = model_checkpoint.split("/")[-1]
training_args = TrainingArguments(
    f"{model_name}-finetuned",
    evaluation_strategy = "epoch",
    num_train_epochs=200,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # By default a checkpoint (weights + optimizer state) is written every
    # `save_steps` steps and none are ever deleted; over 200 epochs at batch size 1
    # that can exhaust the runtime's disk. Keep only the two most recent.
    save_total_limit=2,
)

In [None]:
# Train on the grouped train split; evaluate on the grouped validation split.
split_datasets = {
    "train_dataset": lm_datasets["train"],
    "eval_dataset": lm_datasets["validation"],
}
trainer = Trainer(model=model, args=training_args, **split_datasets)

In [None]:
# Fine-tune; the validation split is evaluated after each epoch (evaluation_strategy="epoch").
trainer.train()

In [None]:
# Save the fine-tuned model to the directory the inference section reloads from.
# No tokenizer was passed to Trainer, so none is saved here; the inference section
# loads the tokenizer from the Hub instead.
# Uncomment the line matching the checkpoint that was trained.
trainer.save_model('model-wangchanberta')
# trainer.save_model('model-gpt2')
# trainer.save_model('model-mT5')

# ทดสอบ Model (Test the model)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Reload the fine-tuned weights with the tokenizer of the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
# is_decoder=True keeps generation causal for the RoBERTa-family CausalLM head,
# matching how a causal LM predicts one token at a time.
model = AutoModelForCausalLM.from_pretrained('model-wangchanberta', is_decoder=True)

# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# model = AutoModelForCausalLM.from_pretrained('model-gpt2')

# tokenizer = AutoTokenizer.from_pretrained("google/mt5-base")
# model = AutoModelForSeq2SeqLM.from_pretrained('model-mT5')  # same relative dir trainer.save_model used

In [None]:
device = torch.device('cpu')
# The prompt tensor is moved to `device` before generation, so the model must live
# there too; otherwise switching to a GPU device raises a device-mismatch error.
model = model.to(device)

In [None]:
# Thai prompt, roughly "Love is like ..."; the model is asked to continue it.
text = "ความรักก็เหมือน"

In [None]:
preprocess_text = text.strip().replace(" ","")
# Training lines were written by export2txt as space-joined pythainlp tokens, so the
# prompt is segmented the same way before encoding.
prepared_Text = ' '.join(word_tokenize(preprocess_text))
print("original text preprocessed: \n", preprocess_text)
tokenized_text = tokenizer.encode(prepared_Text, return_tensors="pt")
# RoBERTa-style tokenizers (e.g. WangChanBERTa) append a separator token `</s>`.
# In the grouped training stream that token ends a line, so generating after it
# starts an unrelated new line instead of continuing the prompt. Drop it.
# Tokenizers without a separator (GPT-2, T5) report None and are left unchanged.
sep_id = tokenizer.sep_token_id
if (sep_id is not None and tokenized_text.shape[-1] > 0
        and tokenized_text[0, -1].item() == sep_id):
    tokenized_text = tokenized_text[:, :-1]
tokenized_text = tokenized_text.to(device)

In [None]:
# Generate a 30-100 token continuation. No sampling or beam arguments are passed, so
# decoding follows the model config's defaults (greedy unless the config says otherwise).
# NOTE: early_stopping only affects beam search (num_beams > 1).
generation_kwargs = {"min_length": 30, "max_length": 100, "early_stopping": True}
text_generate = model.generate(tokenized_text, **generation_kwargs)
output = tokenizer.decode(
    text_generate[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
separator = "=" * 100
print("original text: \n" + preprocess_text)
print(separator)
print("generate text: \n" + output.strip().capitalize())

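# Results

Sample generations for the prompt `ความรักก็เหมือน` ("Love is like ...") from each fine-tuned model.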

## WangChanBERTa
ความรักก็เหมือนชม ชมา   1  1 1 1 1 1 1 1 1 1 11 111111 111111111111111111111111111111111111111111111111111111111



## GPT-2
ความรักก็เหมือน หาม จั่ว   แต่ ถ้า เธอ มี ๆ   เรา จะ เจอ �



## google/mT5-base
ความรักก็เหมือน... ความรักก็เหมือน... ความรักก็เหมือน... ความรักก็เหมือน... ความรักก็เหมือน... ความรักก็เหมือน