# Import Dependencies

In [1]:
# List the NVIDIA GPUs visible on this machine (driver/CUDA version, memory, utilisation).
!nvidia-smi

Sun Apr 14 13:05:51 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:04:00.0 Off |                  N/A |
| 53%   43C    P8             39W /  390W |      10MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090        Off |   00

In [None]:
import torch
import numpy as np
import evaluate

from torch.optim import AdamW
from torch.utils.data import DataLoader

from transformers import get_scheduler
from transformers import AutoTokenizer
from transformers import Seq2SeqTrainer
from transformers import AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq

from transformers import Trainer
from transformers import TrainerCallback
from transformers import TrainingArguments

from pythainlp.tokenize import word_tokenize
from datasets import load_dataset, Dataset, load_from_disk

In [3]:
# Count the CUDA devices PyTorch can see, then select the GPU when one is
# usable and fall back to the CPU otherwise.
n_gpu = torch.cuda.device_count()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"using {n_gpu} {device} device")

using 2 cuda device


# Load Tokenizer

In [None]:
# Hub id of the pretrained checkpoint; the model cell loads the same id.
model_checkpoint = "google/mt5-small"
# use_fast=False requests the slow (pure-Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    use_fast=False,
)

# Load Model

Encoder - Decoder

In [5]:
# Load the pretrained seq2seq (encoder-decoder) weights in bfloat16.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_checkpoint,
    torch_dtype=torch.bfloat16,
)
# Switch off the `use_cache` flag on the model config before fine-tuning.
model.config.use_cache = False

# Load Dataset

In [6]:
# Load the dataset stored on disk; later cells read its "train" and
# "validation" splits.
tokenized_datasets = load_from_disk("preprocessed_thaisum.hf")

# Evaluation Metric

In [7]:
# Load the ROUGE metric through the `evaluate` library; compute_metrics calls
# rouge.compute(...) on the decoded predictions and references.
rouge = evaluate.load("rouge")


def deep_tokenize(word):
    """Tokenize ``word`` with PyThaiNLP's ``word_tokenize`` using the "deepcut" engine.

    ``compute_metrics`` passes this function to ``rouge.compute`` as the
    ``tokenizer`` argument.

    Args:
        word: Text to tokenize.

    Returns:
        Whatever ``word_tokenize`` returns for the given text.
    """
    tokens = word_tokenize(word, engine="deepcut")
    return tokens


def compute_metrics(predictions, labels=None):
    """Decode predictions and references and score them with ROUGE.

    Works both as a ``Trainer`` ``compute_metrics`` hook, which is called with
    a single ``EvalPrediction``-like object or ``(predictions, labels)`` pair,
    and with the two arrays passed as separate arguments.

    Reads the module-level ``tokenizer``, ``rouge`` and ``deep_tokenize``.

    Args:
        predictions: Token ids of shape (batch, seq), or logits of shape
            (batch, seq, vocab), or a tuple of model outputs whose first item
            is one of those. When ``labels`` is omitted, this is instead the
            ``EvalPrediction`` / ``(predictions, labels)`` pair.
        labels: Reference token ids; positions equal to -100 are treated as
            padding.

    Returns:
        dict mapping each ROUGE score name, plus ``"gen_len"`` (mean number of
        non-padding predicted tokens), to a value rounded to 4 decimals.
    """
    if labels is None:
        # Trainer invokes compute_metrics(eval_pred) with ONE argument.
        eval_pred = predictions
        if hasattr(eval_pred, "predictions"):
            predictions, labels = eval_pred.predictions, eval_pred.label_ids
        else:
            predictions, labels = eval_pred

    # Seq2seq models can return several outputs; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    # A plain Trainer delivers logits; reduce them to the most likely token id.
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)

    # -100 marks ignored positions and is not a valid token id for decoding.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    if decoded_preds and decoded_labels:
        # Show one example pair so the evaluation log can be sanity-checked.
        print("label =", decoded_labels[0])
        print("predict =", decoded_preds[0])

    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True,
        tokenizer=deep_tokenize,
    )

    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

# Data Collator & Data Loader

In [8]:
# Collator built from the tokenizer and the seq2seq model; it is handed to the
# Trainer as `data_collator`.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [9]:
class EarlyStoppingCallback(TrainerCallback):
    """Stop training once a fixed number of optimisation steps has been reached.

    This is a step-count cut-off (useful for smoke tests), not a metric-based
    early stop.

    Args:
        num_steps: Training stops at the end of the first step whose
            ``state.global_step`` is greater than or equal to this value.
    """

    def __init__(self, num_steps=10):
        self.num_steps = num_steps

    def on_step_end(self, args, state, control, **kwargs):
        """Request a stop via ``control`` once ``num_steps`` steps have run.

        The callback handler replaces its control object with any non-None
        value returned here, so the ``TrainerControl`` itself must be mutated
        and returned; returning a plain dict would break the training loop.
        """
        if state.global_step >= self.num_steps:
            control.should_training_stop = True
        return control

# Fine-tune a pretrained model

- https://huggingface.co/docs/transformers/en/training
- https://huggingface.co/docs/transformers/main/en/trainer

In [10]:
training_args = TrainingArguments(
    # Output directory handed to the Trainer.
    output_dir="test_trainer",
    # Evaluate and checkpoint on the same per-epoch cadence.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # Optimisation: 3 epochs, linear LR schedule starting from 2e-5.
    num_train_epochs=3,
    learning_rate=2e-5,
    lr_scheduler_type="linear",
    optim="paged_adamw_32bit",
    # Batch sizes are per device.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # load_best_model_at_end=True
)

In [None]:
def _logits_to_token_ids(logits, labels):
    """Reduce a batch of model logits to predicted token ids.

    Passed to the Trainer as ``preprocess_logits_for_metrics`` so evaluation
    accumulates (batch, seq) ids instead of full-vocabulary logits, and so
    ``compute_metrics`` receives ids it can decode. The ids come from the
    evaluation forward pass (argmax per position), not from ``generate()``.

    Args:
        logits: Logits tensor, or a tuple of model outputs whose first item is
            the logits tensor.
        labels: Label ids for the batch (unused; required by the hook signature).

    Returns:
        Tensor of argmax token ids over the last (vocabulary) dimension.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
    # Shrink logits to token ids per batch before they are gathered for metrics.
    preprocess_logits_for_metrics=_logits_to_token_ids,
)

In [12]:
# Optional switch for restricting the run to one GPU (two were detected above).
# It is False here, so the override below is skipped and nothing changes.
limit_n_gpu_to_1 = False
if limit_n_gpu_to_1:
    # NOTE(review): this writes a private TrainingArguments attribute — confirm
    # the installed transformers version still honours `_n_gpu`.
    trainer.args._n_gpu = 1

In [None]:
# Passed straight to trainer.train(); it is False here, so this run does not resume.
# Per the original note, enabling it is meant to resume from the latest checkpoint.
resume_from_checkpoint = False
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

# Save model

In [None]:
# Save the trained model through the Trainer. No path is given, so the Trainer
# picks the destination — presumably `output_dir` from training_args; confirm.
trainer.save_model()