In [6]:
import torch
import re

# Redefine Tokenizer class
class Tokenizer:
    """Character-level tokenizer backed by a character -> index dictionary.

    The class is declared in the notebook so that a previously saved
    ``Tokenizer`` object can be loaded with ``torch.load``.
    NOTE(review): attributes of a loaded instance (e.g. a populated
    ``chDict``) presumably come from the saved file, not from ``__init__`` —
    confirm against the code that produced the file.
    """

    def __init__(self, train_filepath=None, test_filepath=None):
        # test_filepath is accepted for signature compatibility but not stored.
        self.train_filepath = train_filepath
        # Maps a single character to its integer index; starts out empty.
        self.chDict = {}

    def clean_text(self, text):
        """Collapse each run of whitespace into one space and trim both ends."""
        collapsed = re.sub(r'\s+', ' ', text)
        return collapsed.strip()

    def tokenize(self, text):
        """Return the cleaned text as a list of single-character strings."""
        return [ch for ch in self.clean_text(text)]

    def char_to_idx(self, text):
        """Look up every character of ``text`` in ``chDict``.

        Characters missing from the dictionary are mapped to index 0.
        """
        indices = []
        for ch in text:
            indices.append(self.chDict.get(ch, 0))
        return indices

# Now load
# SECURITY: weights_only=False makes torch.load use the full pickle machinery,
# which can execute arbitrary code embedded in the file. Only load a
# Tokenizer.pt that comes from a source you trust.
# NOTE(review): presumably the file holds a pickled Tokenizer instance, which is
# why the class is declared in this cell before loading — confirm against how
# the file was saved.
tokenizer = torch.load("pre-train_model/Tokenizer.pt", weights_only=False)
print("Tokenizer loaded successfully!")

Tokenizer loaded successfully!


In [7]:
import torch.nn as nn

# --- Define only the model architecture (still needed) ---
class BiLSTM_Seg(nn.Module):
    """Per-position two-class tagger: embedding -> one-layer BiLSTM -> linear head.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each embedding vector (also the LSTM input size).
        hidden_dim: Hidden size of each LSTM direction; the linear head
            receives ``2 * hidden_dim`` features because both directions
            are concatenated.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        # The attribute names below are also the state_dict key prefixes, so
        # they must stay as they are for saved weights to load.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        lstm_options = dict(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.bilstm = nn.LSTM(**lstm_options)
        self.classifier = nn.Linear(in_features=2 * hidden_dim, out_features=2)

    def forward(self, x):
        """Return logits with a trailing dimension of size 2 for every position.

        ``x`` is a tensor of integer indices; the LSTM is batch-first, and the
        caller in this notebook passes a tensor of shape (1, seq_len).
        """
        # The LSTM's final (hidden, cell) state is not needed, only the
        # per-position outputs.
        lstm_out, _ = self.bilstm(self.embedding(x))
        return self.classifier(lstm_out)

# --- Load model weights ---
# The embedding table gets one row more than there are entries in chDict.
# NOTE(review): Tokenizer.char_to_idx returns 0 for characters missing from
# chDict, so the extra row is presumably index 0 — confirm that chDict values
# start at 1.
vocab_size = len(tokenizer.chDict) + 1
# NOTE(review): embed_dim / hidden_dim presumably have to match the values the
# checkpoint was trained with, otherwise load_state_dict raises a size mismatch.
embed_dim = 128
hidden_dim = 128

model = BiLSTM_Seg(vocab_size, embed_dim, hidden_dim)
model.load_state_dict(
    torch.load(
        "pre-train_model/segmentation_model.pth",
        map_location="cpu",
        # A state_dict is loaded here, so keep torch.load in its restricted
        # unpickling mode explicitly rather than relying on the
        # version-dependent default.
        weights_only=True,
    )
)
# Switch to inference mode; as the last expression of the cell this also
# displays the module summary.
model.eval()

BiLSTM_Seg(
  (embedding): Embedding(124, 128)
  (bilstm): LSTM(128, 128, batch_first=True, bidirectional=True)
  (classifier): Linear(in_features=256, out_features=2, bias=True)
)

In [8]:
def predict_segmentation(text, model, tokenizer):
    """Insert spaces into ``text`` at the word boundaries predicted by ``model``.

    Whitespace already present in ``text`` is discarded before prediction, so
    the result depends only on the non-space characters.

    Args:
        text: Raw input string.
        model: Callable that takes an index tensor of shape (1, seq_len) and
            returns logits whose last dimension holds the per-class scores.
            Its ``eval()`` method is called first, so the model is left in
            evaluation mode.
        tokenizer: Object providing ``clean_text(str) -> str`` and
            ``char_to_idx(str) -> list[int]``.

    Returns:
        The predicted words joined by single spaces, or ``""`` when no
        characters remain after whitespace removal.
    """
    model.eval()
    with torch.no_grad():
        # Build the space-free character string once; it feeds both the index
        # lookup and the word reconstruction below.
        chars = tokenizer.clean_text(text).replace(" ", "")
        char_indices = tokenizer.char_to_idx(chars)

        if not char_indices:
            return ""

        # Add a leading batch dimension of size 1 for the model.
        input_tensor = torch.tensor(char_indices).unsqueeze(0)
        logits = model(input_tensor)
        predictions = torch.argmax(logits, dim=-1).squeeze(0).tolist()

    segmented_words = []
    current_word = []

    for char, label in zip(chars, predictions):
        # Label 0 closes the word collected so far and starts a new one at
        # this character; any other label extends the current word. A label 0
        # on the very first character has nothing to close yet.
        if label == 0 and current_word:
            segmented_words.append("".join(current_word))
            current_word = [char]
        else:
            current_word.append(char)

    if current_word:
        segmented_words.append("".join(current_word))

    return " ".join(segmented_words)

In [9]:
# Smoke test: run the segmenter on a sample Khmer sentence written without
# spaces, then print the raw input next to the space-separated result.
sentence = "នេះជាភាសារបស់ខ្មែរ"
segmented = predict_segmentation(sentence, model, tokenizer)
print("Input:", sentence)
print("Segmented:", segmented)

Input: នេះជាភាសារបស់ខ្មែរ
Segmented: នេះ ជា ភាសា របស់ ខ្មែរ
