In [None]:
!pip install transformers datasets accelerate

In [1]:
# Read the raw text corpus from ../data/test.vi, one list entry per line
# (trailing newlines are still attached at this point).
with open("../data/test.vi", encoding="utf-8") as source:
    texts = list(source)

# Cleaning and normalisation helper applied to every corpus line.
def clean(text):
    """Return ``text`` with surrounding whitespace removed and lower-cased."""
    stripped = text.strip()
    return stripped.lower()
# Skip lines that are blank after stripping and normalise the remaining ones.
texts = [clean(raw) for raw in texts if raw.strip()]

# Save the processed corpus, one entry per line.
with open("cleaned_data.txt", "w", encoding="utf-8") as out_file:
    out_file.writelines(entry + "\n" for entry in texts)


In [None]:
from transformers import GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling

# Use the end-of-sequence token for padding in case the tokenizer has no
# pad token of its own.
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token

# Collator configured with mlm=False, i.e. without the masked-LM objective.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training examples are built from the cleaned corpus with block_size=128.
# NOTE(review): TextDataset appears to be deprecated in recent transformers
# releases — confirm against the installed version.
train_dataset = TextDataset(
    tokenizer=tokenizer, file_path="cleaned_data.txt", block_size=128
)


In [None]:
from transformers import GPT2LMHeadModel, Trainer, TrainingArguments

# Start from the pretrained distilgpt2 checkpoint.
model = GPT2LMHeadModel.from_pretrained("distilgpt2")

# Run configuration: 3 epochs, batch size 8 per device, a log entry every
# 200 steps, a checkpoint every 1000 steps with at most 2 checkpoints kept.
training_args = TrainingArguments(
    output_dir="../models/gpt2-word-predict",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    logging_steps=200,
    save_steps=1000,
    save_total_limit=2,
)

# Wire the model, data and collator together, then launch fine-tuning.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)
trainer.train()


In [None]:
# Store the fine-tuned weights and the tokenizer files in the same directory
# so both can be reloaded from that single path with from_pretrained().
save_dir = "../models/gpt2-word-predict"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)


In [None]:
from src.predictor import WordPredictor

# Load the fine-tuned model and request up to five suggestions starting with "m".
# NOTE(review): the prompt itself already ends in "m", so the model scores what
# follows that token — confirm whether the prompt should be just "next".
predictor = WordPredictor("../models/gpt2-word-predict")
suggestions = predictor.suggest_next("next m", top_k=5, prefix_filter="m")
suggestions


In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

class WordPredictor:
    """Suggest likely next tokens for a prompt using a saved GPT-2 model."""

    def __init__(self, model_path):
        """Load the tokenizer and language model stored at ``model_path``.

        Args:
            model_path: Location passed unchanged to ``from_pretrained`` for
                both the tokenizer and the model.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_path)
        # Inference only: put the model in evaluation mode.
        self.model.eval()

    def suggest_next(self, prompt, top_k=5, prefix_filter=None):
        """Return up to ``top_k`` distinct next-token suggestions for ``prompt``.

        Tokens are visited from most to least likely. Each one is decoded and
        stripped of surrounding whitespace; empty results and repeats (for
        example " the" and "the") are skipped. When ``prefix_filter`` is given,
        only suggestions starting with it are kept. The whole ranking is
        searched until ``top_k`` matches are found, so a filter no longer
        silently returns a short or empty list just because the matches lie
        outside the first few tokens.

        Args:
            prompt: Text whose continuation should be predicted.
            top_k: Maximum number of suggestions to return.
            prefix_filter: Optional case-sensitive prefix every suggestion
                must start with.

        Returns:
            A list of at most ``top_k`` strings, most likely first. Empty when
            ``top_k`` is not positive.

        Raises:
            ValueError: If ``prompt`` encodes to zero tokens.
        """
        if top_k <= 0:
            return []

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        if input_ids.numel() == 0:
            # The model needs at least one position to predict from.
            raise ValueError("prompt must contain at least one token")

        with torch.no_grad():
            # Scores for the token that would follow the last prompt position.
            logits = self.model(input_ids).logits[0, -1, :]

        # Full ranking as plain ints; decoding happens lazily below, so the
        # loop normally stops after a handful of tokens.
        ranked_ids = torch.argsort(logits, descending=True).tolist()

        suggestions = []
        seen = set()
        for token_id in ranked_ids:
            word = self.tokenizer.decode([token_id]).strip()
            if not word or word in seen:
                continue
            if prefix_filter and not word.startswith(prefix_filter):
                continue
            seen.add(word)
            suggestions.append(word)
            if len(suggestions) >= top_k:
                break
        return suggestions
