# Thai Story Generator with pretrained model

In [None]:
# !pip install datasets

## Import datasets

In [1]:
from datasets import load_dataset
ds_train = load_dataset("thaisum", split="train")


Reusing dataset thaisum (/home/i7g7-1080ti/.cache/huggingface/datasets/thaisum/thaisum/1.0.0/347b33c852af4d796e1224e00e15142d626608a9fa3e07ad6d19dfd8fcae5423)


In [2]:
ds_valid = load_dataset("thaisum", split="validation")

Reusing dataset thaisum (/home/i7g7-1080ti/.cache/huggingface/datasets/thaisum/thaisum/1.0.0/347b33c852af4d796e1224e00e15142d626608a9fa3e07ad6d19dfd8fcae5423)


In [3]:
from datasets import DatasetDict

In [4]:
raw_datasets = DatasetDict(
    {
        "train": ds_train.shuffle().select(range(50000)),
        "valid": ds_valid.shuffle().select(range(500))
    }
)

In [5]:
for key in raw_datasets["train"][0]:
    print(f"{key.upper()}: {raw_datasets['train'][0][key][:200]}")

TITLE: จนท.ทหารสนธิกำลังจับ อดีต ส.จ.น่าน ขับกระบะลอบขนไม้สักทองในอช.ศรีน่าน
BODY: เมื่อวันที่ 8 ธันวาคม 59 เวลาประมาณ 05.30 น. พ.อ.รุศมนตรี จิณเสน ผอ.กอ.รมน.จ.น่าน, นายดงพล รุจิธรรมธัช นายอำเภอเวียงสา, นายเจริญ โพธิรินท์ ผู้ช่วยหัวหน้าอุทยานแห่งชาติศรีน่าน ได้สนธิกำลัง จนท.ทหาร ชก.
SUMMARY: ทหาร-จนท.ปกครอง จ.น่าน สนธิกำลังรวบอดีต ส.จ.2 สมัย ดัดแปลงรถขนไม้สักทอง บริเวณในเขตอุทยานแห่งชาติศรีน่าน นำตัวส่งสภ.น้ำมวบ ดำเนินคดี พร้อมยื่นคัดค้านการประกันตัว เนื่องจากเป็นผู้มีอิทธิพลในพื้นที่
TYPE: ข่าว,ทั่วไทย
TAGS: อุทยานแห่งชาติศรีน่าน,ชนาประทิน อุ่นถา,น่าน,ลักลอบขนไม้สัก,อดีต ส.จ.
URL: https://www.thairath.co.th/news/local/805797


## Preprocessing

In [6]:
from transformers import AutoTokenizer

context_length = 128
tokenizer = AutoTokenizer.from_pretrained("flax-community/gpt2-base-thai")

outputs = tokenizer(
    raw_datasets["train"][:2]["body"],
    truncation=True,
    max_length=context_length,
    return_overflowing_tokens=True,
    return_length=True,
)

print(f"Input IDs length: {len(outputs['input_ids'])}")
print(f"Input chunk lengths: {(outputs['length'])}")
print(f"Chunk mapping: {outputs['overflow_to_sample_mapping']}")

Input IDs length: 18
Input chunk lengths: [128, 128, 128, 128, 128, 52, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 124]
Chunk mapping: [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [7]:
def tokenize(element):
    outputs = tokenizer(
        element["body"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == context_length:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}


tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
)
tokenized_datasets

  0%|          | 0/50 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 463424
    })
    valid: Dataset({
        features: ['input_ids'],
        num_rows: 6972
    })
})

In [8]:
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig

config = AutoConfig.from_pretrained(
    "flax-community/gpt2-base-thai",
    vocab_size=len(tokenizer),
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

In [9]:
model = GPT2LMHeadModel(config)
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 size: {model_size/1000**2:.1f}M parameters")

GPT-2 size: 124.4M parameters


In [10]:
keytoken_ids = []
for keyword in [
    "สวัสดี",
    "ผม",
    "ชอบ",
    "กิน",
    "รัก",
    "อร่อย",
    "หอม",
    "ตู่",
    "ถาม",
    "คิดถึง",
    "คิดถึงจังอิ",
]:
    ids = tokenizer([keyword]).input_ids[0]
    if len(ids) == 1:
        keytoken_ids.append(ids[0])
    else:
        print(f"Keyword has not single token: {keyword}")

Keyword has not single token: สวัสดี
Keyword has not single token: กิน
Keyword has not single token: รัก
Keyword has not single token: อร่อย
Keyword has not single token: ตู่
Keyword has not single token: คิดถึง
Keyword has not single token: คิดถึงจังอิ


In [11]:
from torch.nn import CrossEntropyLoss
import torch

def keytoken_weighted_loss(inputs, logits, keytoken_ids, alpha=1.0):
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # Calculate per-token loss
    loss_fct = CrossEntropyLoss(reduce=False) #change to reduction=None
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    # Resize and average loss per sample
    loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(axis=1)
    # Calculate and scale weighting
    weights = torch.stack([(inputs == kt).float() for kt in keytoken_ids]).sum(
        axis=[0, 2]
    )
    weights = alpha * (1.0 + weights)
    # Calculate weighted average
    weighted_loss = (loss_per_sample * weights).mean()
    return weighted_loss

In [12]:
from torch.utils.data.dataloader import DataLoader

tokenized_datasets.set_format("torch")
train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=16 ,shuffle=True)
eval_dataloader  = DataLoader(tokenized_datasets["valid"], batch_size=16)

In [13]:
weight_decay = 0.1


def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

In [14]:
def evaluate():
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch["input_ids"], labels=batch["input_ids"])
            outputs.loss = outputs.loss.reshape(1)
        losses.append(accelerator.gather(outputs.loss))        
    loss = torch.mean(torch.cat(losses))
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity.item()

In [15]:
model = GPT2LMHeadModel(config)

In [16]:
from torch.optim import AdamW

optimizer = AdamW(get_grouped_params(model), lr=5e-4)

## Accelerator

Now let’s prepare the model, optimizer, and dataloaders so we can start training:

In [None]:
# !pip install accelerate

In [17]:
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision='fp16')

model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [18]:
from transformers import get_scheduler

num_train_epochs = 1
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=1_000,
    num_training_steps=num_training_steps,
)

Login to huggingface

In [19]:
from huggingface_hub import notebook_login

notebook_login()

Login successful
Your token has been saved to /home/i7g7-1080ti/.huggingface/token


In [20]:
from huggingface_hub import Repository, get_full_repo_name

model_name = "GPT_Thai"
repo_name = get_full_repo_name(model_name)
repo_name

'Earth1221/GPT_Thai'

In [22]:
import  os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

output_dir = "GPT_Thai"
repo = Repository(output_dir, clone_from=repo_name)

Cloning https://huggingface.co/Earth1221/GPT_Thai into local empty directory.


In [23]:
evaluate()

(11.026223182678223, 61464.99609375)

In [24]:
from tqdm.notebook import tqdm

gradient_accumulation_steps = 8
eval_steps = 5_000

model.train()
completed_steps = 0
for epoch in range(num_train_epochs):
    for step, batch in tqdm(
        enumerate(train_dataloader, start=1), total=num_training_steps
    ):
        logits = model(batch["input_ids"]).logits
        loss = keytoken_weighted_loss(batch["input_ids"], logits, keytoken_ids)
        if step % 100 == 0:
            accelerator.print(
                {
                    "steps": completed_steps,
                    "loss/train": loss.item() * gradient_accumulation_steps,
                }
            )
        loss = loss / gradient_accumulation_steps
        accelerator.backward(loss)
        if step % gradient_accumulation_steps == 0:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            completed_steps += 1
        if (step % (eval_steps * gradient_accumulation_steps)) == 0:
            eval_loss, perplexity = evaluate()
            accelerator.print({"loss/eval": eval_loss, "perplexity": perplexity})
            model.train()
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
            if accelerator.is_main_process:
                tokenizer.save_pretrained(output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress step {step}", blocking=False
                )

  0%|          | 0/28964 [00:00<?, ?it/s]



{'steps': 12, 'loss/train': 97.81797790527344}
{'steps': 24, 'loss/train': 74.01036071777344}
{'steps': 37, 'loss/train': 75.79133605957031}
{'steps': 49, 'loss/train': 71.50773620605469}
{'steps': 62, 'loss/train': 61.85414505004883}
{'steps': 74, 'loss/train': 67.1125259399414}
{'steps': 87, 'loss/train': 55.62742614746094}
{'steps': 99, 'loss/train': 58.50541687011719}
{'steps': 112, 'loss/train': 52.93339538574219}
{'steps': 124, 'loss/train': 47.67146301269531}
{'steps': 137, 'loss/train': 44.43110656738281}
{'steps': 149, 'loss/train': 41.41243362426758}
{'steps': 162, 'loss/train': 40.23468017578125}
{'steps': 174, 'loss/train': 47.499568939208984}
{'steps': 187, 'loss/train': 40.43970489501953}
{'steps': 199, 'loss/train': 46.107749938964844}
{'steps': 212, 'loss/train': 40.13477325439453}
{'steps': 224, 'loss/train': 39.6888427734375}
{'steps': 237, 'loss/train': 38.55208969116211}
{'steps': 249, 'loss/train': 37.895843505859375}
{'steps': 262, 'loss/train': 39.684967041015625

In [38]:
#model.save_pretrained(output_dir)
repo.push_to_hub(commit_message=f"Update model", blocking=False)

In [39]:
# import torch
# from transformers import pipeline

# pipe = pipeline("text-generation", max_length=100, pad_token_id=0, eos_token_id=0, model="Earth1221/GPT_Thai", tokenizer="Earth1221/GPT_Thai")

Downloading:   0%|          | 0.00/966 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/487M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/235 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.19M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/881k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.86M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

In [44]:
# txt = "ลุง"
# print(pipe(txt, num_return_sequences=1)[0]["generated_text"])

ลุงและห้องเก็บสุนัขจัดป่าสัตว์ และแม้ว่า จะทำให้เกิดบุคคลผิดกฎหมายและเป็นเจ้าพนักงานปฏิบัติหรือละเว้น แต่ยังจำคุกอีกนับไม่แน่ชัดจะยื่นตีพิเศษในวันพรุ่งนี้ อีกทั้งได้ให้


## Try many search algorithms

In [1]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer


tokenizer = GPT2Tokenizer.from_pretrained('Earth1221/GPT_Thai')

# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained("Earth1221/GPT_Thai", pad_token_id=tokenizer.eos_token_id)

In [12]:
input_ids = tokenizer.encode('ลุงตู่', return_tensors='pt')

In [13]:
import torch
torch.manual_seed(0)

### Greedy

In [17]:
greedy_output = model.generate(input_ids, max_length=50)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50265 for open-end generation.


Output:
----------------------------------------------------------------------------------------------------
ลุงตู่,ทั้งนี้ นายวิรัตน์ ก็ยังไม่ได้รับการร้องเรียนจากเจ้าหน้าที่ตำรวจว่า นายวิรัตน์


### Beam search

In [20]:
beam_output = model.generate(
    input_ids,  
    max_length=50, 
    num_beams=2, 
    early_stopping=True
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50265 for open-end generation.


Output:
----------------------------------------------------------------------------------------------------
ลุงตู่มีลักษณะคล้ายกับแม่น้ำเจ้าพระยาจึงได้มีการขุดค้นพบแม่น้ำเจ้าพระยาจึงได้มี


### Sampling

In [16]:
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=0
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50265 for open-end generation.


Output:
----------------------------------------------------------------------------------------------------
ลุงตู่ในอำนาจ และอุดมการณ์เดียวกัน บังเอิญให้กำลังใจ กิ่งอายุ ค่อนข้างชันใดบังเอิญ,ความคืบหน้าล่าส


### Top-p (nucleus) sampling

In [14]:


# set top_k to 50
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_p=0.92, 
    top_k=0
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50265 for open-end generation.


Output:
----------------------------------------------------------------------------------------------------
ลุงตู่ เสียหลักชนเส้นเลียบเข้าท่วมเช่นนั้น นายศุกรพลเกษียณอายุกับ น.ส.ชรัฏฐิยานุช โคตรเหล็
