In [1]:
import warnings
warnings.filterwarnings('ignore')
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning)

In [2]:
!pip install py_vncorenlp
!pip install streamlit
!pip install streamlit localtunnel
!pip install imutils
!npm install -g localtunnel

Collecting py_vncorenlp
  Downloading py_vncorenlp-0.1.4.tar.gz (3.9 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyjnius (from py_vncorenlp)
  Downloading pyjnius-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Downloading pyjnius-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hBuilding wheels for collected packages: py_vncorenlp
  Building wheel for py_vncorenlp (setup.py) ... [?25l[?25hdone
  Created wheel for py_vncorenlp: filename=py_vncorenlp-0.1.4-py3-none-any.whl size=4305 sha256=f6b5756177a34125e3b49315d043cdc4427f0b39b0f12b7a7568e307f57d90d6
  Stored in directory: /root/.cache/pip/wheels/d5/d9/bf/62632cdb007c702a0664091e92a0bb1f18a2fcecbe962d9827
Successfully built py_vncorenlp
Installing collected packages: pyjnius, py_vncorenlp
Successfully installed py

In [9]:
%%writefile app_temp.py

import streamlit as st
import pickle
import numpy as np
from datetime import datetime
import os
from scipy.optimize import fmin_l_bfgs_b
from collections import defaultdict, Counter
from tabulate import tabulate
from time import time
from math import ceil, floor
from itertools import product
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
import pandas as pd
import py_vncorenlp

BEGIN = '*'
DEFAULT_CUTOFF_FRACT = 0.3333333

@st.cache_resource
def load_vncorenlp_global():
    """Load VnCoreNLP once globally and cache it"""
    try:
        # Download model if not exists
        py_vncorenlp.download_model(save_dir='/kaggle/working/')
        segmenter = py_vncorenlp.VnCoreNLP(save_dir='/kaggle/working/')
        return segmenter
    except Exception as e:
        st.error(f"Error loading VnCoreNLP: {str(e)}")
        return None

# Initialize VnCoreNLP globally at module level
GLOBAL_SEGMENTER = None

def get_global_segmenter():
    """Get the global segmenter instance"""
    global GLOBAL_SEGMENTER
    if GLOBAL_SEGMENTER is None:
        GLOBAL_SEGMENTER = load_vncorenlp_global()
    return GLOBAL_SEGMENTER

def vn_metaphone(word):
    import unicodedata
    return ''.join(
        c for c in unicodedata.normalize('NFD', word.lower())
        if unicodedata.category(c) != 'Mn'
    )

class Timer:
    def __init__(self, name):
        self.name = name
        self.__start_time = None
        self.__end_time = None
        self.start()

    def start(self):
        self.__start_time = time()

    def stop(self):
        self.__end_time = time()
        self.__get_elapsed__()

    def __get_elapsed__(self):
        elapsed = (self.__end_time - self.__start_time)
        unit = "seconds"
        if elapsed >= 3600:
            unit = "minutes"
            hours = elapsed / 3600
            minutes = hours % 60
            hours = floor(hours)
            print(self.name, "took", str(hours), "hours and", "{0:.2f}".format(minutes), unit, "to complete")
        elif elapsed >= 60:
            minutes = floor(elapsed / 60)
            seconds = elapsed % 60
            print(self.name, "took", str(minutes), "minutes and", "{0:.2f}".format(seconds), unit, "to complete")
        else:
            print(self.name, "took", "{0:.2f}".format(elapsed), unit, "to complete")

class HistoryTuple:
    def __init__(self, sequence_id, sentence, tags, index):
        if index < 0 or index >= len(sentence):
            raise IndexError
        self.index = index
        self.sequence_id = sequence_id
        self.sentence = sentence
        self.tags = tags
        self.t2, self.t1 = self.__get_previous_tags__(tags)

    def __get_previous_tags__(self, tags):
        """function to retrieve 2 previous tags for word in sentence"""
        if len(self.tags) == 0:
            return None, None
        if self.index == 1:
            return BEGIN, tags[self.index-1]
        elif self.index == 0:
            return BEGIN, BEGIN
        else:
            return tags[self.index-2], tags[self.index-1]

    def getWord(self):
        return self.sentence[self.index]

    def getWordTag(self):
        return self.tags[self.index]

    def getT_2(self):
        return self.t2

    def getT_1(self):
        return self.t1

    def getIndex(self):
        return self.index

    def getTupleKey(self):
        return self.t2, self.t1, self.sequence_id, self.index

    def getPossibleTagSet(self, data, cutoff=None, add_common=False):
        """Return tag set which are possible for a given word"""
        full_tag_set_size = data.getTagSetSize()
        if cutoff is None:
            cutoff = ceil(full_tag_set_size * DEFAULT_CUTOFF_FRACT)
        elif cutoff >= full_tag_set_size:
            return data.getTagSet()
        
        word = self.getWord()
        tags_dict = data.getWordDict().get(word, False)
        if tags_dict is False:
            sorted_tags_list = data.getSortedTagsList()
        else:
            sorted_tags_list = sorted(tags_dict, key=tags_dict.get, reverse=True)
        
        if data.isNumberWord(word) and "M" not in sorted_tags_list[:cutoff]:
            sorted_tags_list.insert(0, "M")
        
        remainder = cutoff - len(sorted_tags_list)
        if remainder < 0:
            return tuple(sorted_tags_list[:cutoff])
        elif add_common is True and remainder > 0:
            top_candidate_tags = data.getSortedTagsList()
            sorted_tags_set = set(sorted_tags_list)
            candidate_set = set(top_candidate_tags) - sorted_tags_set
            while remainder > 0 and top_candidate_tags:
                tag_candidate = top_candidate_tags.pop(0)
                if tag_candidate in candidate_set:
                    sorted_tags_list.append(tag_candidate)
                    remainder -= 1
        return tuple(sorted_tags_list)

class VietnameseDataReader:
    def __init__(self, sentences, file_name="data"):
        self.file = file_name
        self.sentences = []
        self.tags = []
        self.tag_dict = defaultdict(int)
        self.word_dict = defaultdict(dict)
        self.word_tag_dict = defaultdict(int)
        self.numbers = 0
        self.cap_no_start = 0
        self.word_suffixes = {}
        self.word_prefixes = {}
        self.__process_sentences__(sentences)
        self.__make_tuples__()
        self.tags_bigrams, self.tags_trigrams = self.__tagsToNgrams__()
        self.sorted_tags_list = sorted(self.tag_dict, key=self.tag_dict.get, reverse=True)

    def __process_sentences__(self, sentences):
        """Process sentences in [(word, tag), ...] format"""
        for sentence in sentences:
            words = []
            tags = []
            for word, tag in sentence:
                words.append(word)
                tags.append(tag)
                
                # Update dictionaries
                self.word_tag_dict[(word, tag)] += 1
                self.tag_dict[tag] += 1
                
                if word not in self.word_dict:
                    self.word_dict[word] = defaultdict(int)
                self.word_dict[word][tag] += 1
                
                # Count numbers and capitals
                if self.isNumberWord(word):
                    self.numbers += 1
                if word and word[0].isupper() and len(words) > 1:
                    self.cap_no_start += 1
            
            self.sentences.append(tuple(words))
            self.tags.append(tuple(tags))

    def __make_tuples__(self):
        self.sentences = tuple(self.sentences)
        self.tags = tuple(self.tags)

    def __tagsToNgrams__(self):
        """Create trigrams and bigrams from data"""
        bigrams = defaultdict(int)
        trigrams = defaultdict(int)
        for tags in self.getTags():
            tags = list(tags)
            for i in range(2):
                tags.insert(0, BEGIN)
            for k in range(2, len(tags)):
                trigrams[tuple(tags[k-2:k+1])] += 1
                bigrams[tuple(tags[k-1:k+1])] += 1
        return bigrams, trigrams

    def __wordsToSuffixes__(self):
        """Create suffixes for all word,tag pairs"""
        suffixes = defaultdict(int)
        for word, tag in self.getWordTagDict():
            for suffix in self.getSuffixesForWord(word):
                suffixes[(suffix, tag)] += 1
        return suffixes

    def __wordsToPrefixes__(self):
        """Create prefixes for all word,tag pairs"""
        prefixes = defaultdict(int)
        for word, tag in self.getWordTagDict():
            for prefix in self.getPrefixesForWord(word):
                prefixes[(prefix, tag)] += 1
        return prefixes

    def getSuffixesForWord(self, word):
        """Generate suffixes for a given word"""
        suffixes = self.word_suffixes.get(word, False)
        if suffixes is not False:
            return suffixes
        suffixes = []
        if word.isalpha():
            boundary = min(5, len(word))
            for i in range(1, boundary):
                suffixes.append(word[-i:])
        suffixes = tuple(suffixes)
        self.word_suffixes[word] = suffixes
        return suffixes

    def getPrefixesForWord(self, word):
        """Generate prefixes for a given word"""
        prefixes = self.word_prefixes.get(word, False)
        if prefixes is not False:
            return prefixes
        prefixes = []
        if word.isalpha():
            boundary = min(5, len(word))
            for i in range(2, boundary):
                prefixes.append(word[:i])
        prefixes = tuple(prefixes)
        self.word_prefixes[word] = prefixes
        return prefixes

    @staticmethod
    def isNumberWord(word):
        if word.isdigit():
            return True
        elif word.isnumeric():
            return True
        elif word.isdecimal():
            return True
        else:
            for char in ('-', ',', '.', '\/'):
                word = word.replace(char, '')
                if word.isdigit():
                    return True
            return False

    def getTagSet(self):
        return tuple(self.tag_dict.keys())

    def getTagSetSize(self):
        return len(self.tag_dict)

    def getWordDict(self):
        return self.word_dict

    def getTagDict(self):
        return self.tag_dict

    def getTagDictSize(self):
        return len(self.tag_dict)

    def getWordDictSize(self):
        return len(self.word_dict)

    def getSentences(self):
        return self.sentences

    def getTags(self):
        return self.tags

    def getSentencesSize(self):
        return len(self.sentences)

    def getTagsSize(self):
        return len(self.tags)

    def getSentenceByIndex(self, index):
        return self.sentences[index]

    def getTagsByIndex(self, index):
        return self.tags[index]

    def getWordTagDict(self):
        return self.word_tag_dict

    def getSortedTagsList(self):
        return self.sorted_tags_list.copy()

    def getTopNTagsForWord(self, word, n):
        tags_dict = self.getWordDict().get(word, False)
        if tags_dict is False:
            return self.getTopNTags(n)
        sorted_tags = sorted(tags_dict, key=tags_dict.get, reverse=True)
        if n == 1:
            return sorted_tags[0] if sorted_tags else self.sorted_tags_list[0]
        elif n >= len(sorted_tags):
            return sorted_tags
        else:
            return sorted_tags[:n]

    def getTopNTags(self, n):
        if n == 1:
            return self.sorted_tags_list[0]
        elif n >= len(self.sorted_tags_list):
            return self.sorted_tags_list
        else:
            return self.sorted_tags_list[:n]

    def getNumbers(self):
        return self.numbers

    def getCapNoStart(self):
        return self.cap_no_start

    def getCapStart(self):
        return self.getSentencesSize()

class VietnameseFeaturesFactory:
    def __init__(self, data, cutoff=0):
        self.data = data
        self.type = "vietnamese_optimized"
        self._cutoff = cutoff
        self._features_index = {}
        self.histories_dict = {}
        self.null_histories_set = set()
        
        # Từ đơn âm tiết hay đa âm tiết
        self.syllable_info = self.__analyze_syllables__()
        
        self.__generateFeaturesIndex__()

    def __analyze_syllables__(self):
        syllable_info = {}
        for sentence in self.data.getSentences():
            for word in sentence:
                # Đếm số âm tiết bằng cách đếm số dấu gạch dưới + 1
                if '_' in word:
                    # Từ có dấu gạch dưới: "bóng_đá" = 2 âm tiết, "học_sinh_viên" = 3 âm tiết
                    syllable_count = word.count('_') + 1
                else:
                    # Từ đơn âm tiết: "tôi", "đẹp", "nhà" = 1 âm tiết
                    syllable_count = 1
                syllable_info[word] = syllable_count
        return syllable_info

    def getSyllablePrefixes(self, word):
        """Return all syllable-level prefixes for a word (joined by '_'). E.g. 'học_sinh_viên' -> ['học', 'học_sinh']"""
        if '_' not in word:
            return []
        sylls = word.split('_')
        return ['_'.join(sylls[:i]) for i in range(1, len(sylls))]
    
    def getSyllableSuffixes(self, word):
        """Return all syllable-level suffixes for a word (joined by '_'). E.g. 'học_sinh_viên' -> ['viên', 'sinh_viên']"""
        if '_' not in word:
            return []
        sylls = word.split('_')
        return ['_'.join(sylls[i:]) for i in range(1, len(sylls))]

    def getFeaturesIndices(self, tag, history, in_data=True):
        """Lấy feature indices tối ưu cho tiếng Việt"""
        indices = []
        word = history.getWord()
        position = history.getIndex()
        sentence = history.sentence
        
        # 1. f100: (Word,Tag) pair
        if in_data:
            feature_idx = self._features_index.get(("f100", (word, tag)), False)
            if feature_idx is not False:
                indices.append(feature_idx)

        # 2. Phonetic (Metaphone) feature
        metaphone = vn_metaphone(word)
        feature_idx = self._features_index.get(("fMetaphone", (metaphone, tag)), False)
        if feature_idx is not False:
            indices.append(feature_idx)
        
        # 3. f103: Trigram Tags
        feature_idx = self._features_index.get(("f103", (history.getT_2(), history.getT_1(), tag)), False)
        if feature_idx is not False:
            indices.append(feature_idx)
        
        # 4. Bigram Tags
        feature_idx = self._features_index.get(("f104", (history.getT_1(), tag)), False)
        if feature_idx is not False:
            indices.append(feature_idx)
        
        # 5. Window features: W_{i-2}, W_{i-1}, W_{i+1}, W_{i+2}
        for offset in [-2, -1, 1, 2]:
            idx = position + offset
            if 0 <= idx < len(sentence):
                context_word = sentence[idx]
                feature_idx = self._features_index.get(("fWindow", (offset, context_word, tag)), False)
                if feature_idx is not False:
                    indices.append(feature_idx)
        
        # 6. Word pair features: (W_{i-1}, W_i), (W_i, W_{i+1})
        if position > 0:
            prev_word = sentence[position - 1]
            feature_idx = self._features_index.get(("fWordPair", (prev_word, word, tag)), False)
            if feature_idx is not False:
                indices.append(feature_idx)
        
        if position < len(sentence) - 1:
            next_word = sentence[position + 1]
            feature_idx = self._features_index.get(("fWordPair", (word, next_word, tag)), False)
            if feature_idx is not False:
                indices.append(feature_idx)
        
        # 7. Punctuation feature
        if tag == 'CH':
            feature_idx = self._features_index.get(("fPunct", tag), False)
            if feature_idx is not False:
                indices.append(feature_idx)
        
        # 8. Number feature
        if tag == 'M':
            feature_idx = self._features_index.get(("fNum", tag), False)
            if feature_idx is not False:
                indices.append(feature_idx)
        
        # 9. Quantifier feature
        if tag == 'L':
            feature_idx = self._features_index.get(("fQuantifier", tag), False)
            if feature_idx is not False:
                indices.append(feature_idx)
        
        # 10. Capitalization features
        if word and word[0].isupper():
            if position == 0:
                feature_idx = self._features_index.get(("fCapStart", tag), False)
                if feature_idx is not False:
                    indices.append(feature_idx)
            else:
                feature_idx = self._features_index.get(("fCapNoStart", tag), False)
                if feature_idx is not False:
                    indices.append(feature_idx)
        
        # 11. Syllable count feature
        syllable_count = self.syllable_info.get(word, 1)
        feature_idx = self._features_index.get(("fSyllable", (syllable_count, tag)), False)
        if feature_idx is not False:
            indices.append(feature_idx)
        
        # 12. Word length feature
        length_category = "short" if len(word) <= 3 else "medium" if len(word) <= 5 else "long"
        feature_idx = self._features_index.get(("fLength", (length_category, tag)), False)
        if feature_idx is not False:
            indices.append(feature_idx)

        # 13. SyllablePrefix
        for prefix in self.getSyllablePrefixes(word):
            feature_idx = self._features_index.get(("fSyllablePrefix", (prefix, tag)), False)
            if feature_idx is not False:
                indices.append(feature_idx)
        
        # 14. SyllableSuffix
        for suffix in self.getSyllableSuffixes(word):
            feature_idx = self._features_index.get(("fSyllableSuffix", (suffix, tag)), False)
            if feature_idx is not False:
                indices.append(feature_idx)
        
        # 15. Position in sentence features
        sent_len = len(sentence)
        if position == 0:
            pos_feature = "first"
        elif position == sent_len - 1:
            pos_feature = "last"
        elif position / sent_len < 0.3:
            pos_feature = "early"
        elif position / sent_len > 0.7:
            pos_feature = "late"
        else:
            pos_feature = "middle"
        
        feature_idx = self._features_index.get(("fPosition", (pos_feature, tag)), False)
        if feature_idx is not False:
            indices.append(feature_idx)
        
        return indices

    def __generateFeaturesIndex__(self):
        """Generate features index cho tất cả features"""
        feature_names = [
            "f100", "f103", "f104", "fWindow", "fWordPair", "fPunct", 
            "fNum", "fQuantifier", "fCapStart", "fCapNoStart", 
            "fSyllable", "fLength", "fPosition",
            "fSyllablePrefix", "fSyllableSuffix", "fMetaphone"
        ]
        
        # Build feature dictionaries
        feature_dicts = {}
        
        # Existing features
        feature_dicts["f100"] = self.data.getWordTagDict()
        feature_dicts["f103"] = self.data.tags_trigrams
        feature_dicts["f104"] = self.data.tags_bigrams
        
        # New features
        feature_dicts.update(self.__build_new_feature_dicts__())
        
        # Generate indices
        keys = []
        for name in feature_names:
            if name in feature_dicts:
                features = []
                for feature in feature_dicts[name].keys():
                    if feature_dicts[name].get(feature) > self._cutoff:
                        features.append((name, feature))
                keys.extend(features)
        
        for i, key in enumerate(keys):
            self._features_index[key] = i
        
        self.features_list = tuple(keys)
        self._features_vector_length = len(keys)

    def __build_new_feature_dicts__(self):
        new_dicts = defaultdict(lambda: defaultdict(int))
        
        for k in range(self.data.getSentencesSize()):
            sentence = self.data.getSentenceByIndex(k)
            tags = self.data.getTagsByIndex(k)
            
            for i, (word, tag) in enumerate(zip(sentence, tags)):
                # Window features
                for offset in [-2, -1, 1, 2]:
                    idx = i + offset
                    if 0 <= idx < len(sentence):
                        context_word = sentence[idx]
                        new_dicts["fWindow"][(offset, context_word, tag)] += 1

                metaphone = vn_metaphone(word)
                new_dicts["fMetaphone"][(metaphone, tag)] += 1
                
                # Word pair features
                if i > 0:
                    prev_word = sentence[i - 1]
                    new_dicts["fWordPair"][(prev_word, word, tag)] += 1
                
                if i < len(sentence) - 1:
                    next_word = sentence[i + 1]
                    new_dicts["fWordPair"][(word, next_word, tag)] += 1
                
                # Punctuation feature
                if tag == 'CH':
                    new_dicts["fPunct"][tag] += 1
                
                # Number feature
                if tag == 'M':
                    new_dicts["fNum"][tag] += 1
                
                # Quantifier feature
                if tag == 'L':
                    new_dicts["fQuantifier"][tag] += 1
                    
                # fSyllablePrefix
                for prefix in self.getSyllablePrefixes(word):
                    new_dicts["fSyllablePrefix"][(prefix, tag)] += 1
                
                # fSyllableSuffix
                for suffix in self.getSyllableSuffixes(word):
                    new_dicts["fSyllableSuffix"][(suffix, tag)] += 1
                
                # Syllable count
                syllable_count = self.syllable_info.get(word, 1)
                new_dicts["fSyllable"][(syllable_count, tag)] += 1
                
                # Word length
                length_category = "short" if len(word) <= 3 else "medium" if len(word) <= 5 else "long"
                new_dicts["fLength"][(length_category, tag)] += 1
                
                # Position in sentence
                sent_len = len(sentence)
                if i == 0:
                    pos_feature = "first"
                elif i == sent_len - 1:
                    pos_feature = "last"
                elif i / sent_len < 0.3:
                    pos_feature = "early"
                elif i / sent_len > 0.7:
                    pos_feature = "late"
                else:
                    pos_feature = "middle"
                
                new_dicts["fPosition"][(pos_feature, tag)] += 1
                
                # Capitalization features
                if word and word[0].isupper():
                    if i == 0:
                        new_dicts["fCapStart"][tag] += 1
                    else:
                        new_dicts["fCapNoStart"][tag] += 1
        
        return dict(new_dicts)

    def getFeaturesVectorLength(self):
        return self._features_vector_length
    
    def getCutoffParameter(self):
        return self._cutoff
    
    def getEmpiricalCounts(self):
        """Get empirical counts vector"""
        empirical_counts = np.zeros(self.getFeaturesVectorLength(), dtype=float)
        
        for k in range(self.data.getSentencesSize()):
            sentence = self.data.getSentenceByIndex(k)
            tags = self.data.getTagsByIndex(k)
            
            for i in range(len(sentence)):
                history = HistoryTuple(k, sentence, tags, i)
                features = self.getFeaturesIndices(tags[i], history, True)
                for feature_idx in features:
                    empirical_counts[feature_idx] += 1.0
        
        return empirical_counts


class ViterbiAlgorithm:
    """Viterbi Algorithm for Vietnamese POS tagging"""
    def __init__(self, sequence_id, sentence, sentence_tags, model, cutoff=None):
        self.sequence_id = sequence_id
        self.sentence = sentence
        self.sentence_tags = sentence_tags
        self.data = model.data
        if cutoff is None:
            self.cutoff = self.data.getTagSetSize()
        else:
            self.cutoff = cutoff
        self.tags_set = self.data.getTagSet()
        self.prob_func = model.probability
        self.weights = model.getWeights()
        self.pi = {(-1, BEGIN, BEGIN): 1.0}
        self.bp = {}
        self.tag_sequence = []

    def run(self):
        """Main Viterbi algorithm"""
        sentence_length = len(self.sentence)
        for k in range(sentence_length):
            tag_pairs = tuple(product(self.__calc_possible_tags_set__(k-1), 
                                    self.__calc_possible_tags_set__(k)))
            for u, v in tag_pairs:
                key = (k, u, v)
                self.pi[key], self.bp[k] = self.__calc_max_probability__(key)
                if self.pi[key] == 0.0000:
                    self.bp[k] = self.data.getTopNTagsForWord(self.sentence[k], 1)
        
        self.bp[sentence_length], self.bp[sentence_length+1] = self.__calc_last_tags__(sentence_length)
        for k in range(sentence_length):
            self.tag_sequence.append(self.bp.get(k+2, False))

    def __calc_possible_tags_set__(self, index):
        """Return possible tag set for a given position in the sentence"""
        if index < 0:
            return (BEGIN,)
        return HistoryTuple(self.sequence_id, self.sentence, self.sentence_tags, index).getPossibleTagSet(self.data, self.cutoff, add_common=False)

    def __calc_max_probability__(self, key):
        """Calculate maximum probability for each iteration"""
        k = key[0]
        u = key[1]
        v = key[2]
        if k < 0:
            return 1.0, BEGIN
        max_pi = 0.00000
        max_bp = None
        possible_tags_set = self.__calc_possible_tags_set__(k-2)
        for t in possible_tags_set:
            new_key = (k-1, t, u)
            history = HistoryTuple(self.sequence_id, self.sentence, self.sentence_tags, k)
            pi_value = self.pi.get(new_key, 0.00000) * self.prob_func(v, history, self.weights)
            if pi_value >= max_pi:
                max_pi = pi_value
                max_bp = t
        return max_pi, max_bp

    def __calc_last_tags__(self, sentence_length):
        """Return last 2 tags in the sequence"""
        max_pi = 0.0
        max_bp = ()
        tag_pairs = tuple(product(self.__calc_possible_tags_set__(sentence_length-2), 
                                self.__calc_possible_tags_set__(sentence_length-1)))
        for u, v in tag_pairs:
            key = (sentence_length-1, u, v)
            pi_value = self.pi.get(key, None)
            if pi_value and pi_value >= max_pi:
                max_pi = pi_value
                max_bp = (u, v)
        return max_bp

    def getBestTagSequence(self):
        return tuple(self.tag_sequence)

class MEMM:
    """MEMM model for Vietnamese POS tagging"""
    def __init__(self, feature_factory, regularizer=0, pretrained_weights=None):
        self.data = feature_factory.data
        self.feature_factory = feature_factory
        self.regularizer = float(regularizer)
        self.cache = self.getTrainedWeightsCacheName()
        self.weights = self.__initializeWeights__(pretrained_weights)
        self.train_results = None
        self.predictions = {}
        self.correct_tags = defaultdict(int)
        self.wrong_tags = defaultdict(int)
        self.wrong_tag_pairs = defaultdict(int)
        self.wrong_tags_dicts = {}

    def __initializeWeights__(self, pretrained_weights):
        """Initialize model weights"""
        weights_vector_length = self.feature_factory.getFeaturesVectorLength()
        weights = np.zeros(weights_vector_length, dtype=float)
        if pretrained_weights is True:
            weights = self.loadTrainedWeights(self.getTrainedWeightsCacheName())
        elif pretrained_weights is not None and type(pretrained_weights) is np.ndarray and len(pretrained_weights) == weights_vector_length:
            weights = pretrained_weights
        return weights

    def getWeights(self):
        return self.weights

    def getFeatures(self, tag, history, in_data=False):
        """Get feature instances indices for given tag and HistoryTuple"""
        history_key = (tag, history.getTupleKey())
        if history_key in self.feature_factory.null_histories_set:
            return []
        feature = self.feature_factory.histories_dict.get(history_key, None)
        if feature is None:
            feature = self.feature_factory.getFeaturesIndices(tag, history, in_data)
            if len(feature) == 0:
                self.feature_factory.null_histories_set.add(history_key)
        return feature

    def calc_dot_product(self, features, weights):
        """Calculate dot product between feature and weights vectors"""
        total = 0.0
        for index in features:
            total += weights[index]
        return total

    def calcDenominatorBatch(self, history, weights, cutoff=None):
        """Calculate sum in denominator of probability calculation"""
        full_tag_set_size = self.data.getTagSetSize()
        tag_set = history.getPossibleTagSet(self.data, cutoff, add_common=True)
        remainder = float(full_tag_set_size) - len(tag_set)
        total = 0.0
        for tag in tag_set:
            features = self.getFeatures(tag, history, False)
            if len(features) == 0:
                temp = 1.0
            else:
                temp = np.exp(self.calc_dot_product(features, weights), dtype=float)
            total += temp
        if remainder > 0:
            total += remainder
        if total == 0.0:
            total = 0.0001
        return total

    def calcNominator(self, features, weights):
        """Calculate nominator in probability calculation"""
        if len(features) == 0:
            nominator = 1.0
        else:
            product = self.calc_dot_product(features, weights)
            if product == 0.0:
                nominator = 1.0
            else:
                nominator = np.exp(product, dtype=float)
        return nominator

    def probability(self, tag, history, weights, features=None):
        """Calculate probability of specific tag given history"""
        if features is None:
            features = self.getFeatures(tag, history, True)
        nominator = self.calcNominator(features, weights)
        denominator = self.calcDenominatorBatch(history, weights)
        return float(nominator/denominator)

    def calc_loss(self, weights):
        """Calculate loss function value over entire dataset"""
        timer = Timer("Loss Calculation")
        features_sum = 0.0
        denominators_sum = 0.0
        for k in range(self.data.getSentencesSize()):
            sentence = self.data.getSentenceByIndex(k)
            tags = self.data.getTagsByIndex(k)
            for i in range(len(sentence)):
                history = HistoryTuple(k, sentence, tags, i)
                features_sum += self.calc_dot_product(self.getFeatures(tags[i], history, True), weights)
                denominators_sum += np.log(self.calcDenominatorBatch(history, weights, self.data.getTagSetSize()), dtype=float)
        if self.regularizer == 1.0:
            regularization_sum = np.sum(np.power(weights, 2, dtype=float), dtype=float) / 2.0
        elif self.regularizer != 0.0:
            regularization_sum = self.regularizer * np.sum(np.power(weights, 2, dtype=float), dtype=float) / 2.0
        else:
            regularization_sum = 0.0
        total = regularization_sum + denominators_sum - features_sum
        timer.stop()
        print("Loss:", total)
        return total

    def calcExpectedCountsDict(self, weights):
        """Calculate expected counts"""
        dictionary = defaultdict(float)
        
        # Process all sentences sequentially
        for i in range(self.data.getSentencesSize()):
            sentence = self.data.getSentenceByIndex(i)
            tags = self.data.getTagsByIndex(i)
            for j in range(len(sentence)):
                history = HistoryTuple(i, sentence, tags, j)
                self.calcExpectedCountsBatchInternal(history, weights, dictionary)
        
        return Counter(dictionary)

    def calcExpectedCountsBatchInternal(self, history, weights, dictionary):
        """Internal function for expected counts calculation"""
        cutoff = self.data.getTagSetSize()
        tag_set = history.getPossibleTagSet(self.data, cutoff, add_common=True)
        for tag in tag_set:
            features = self.getFeatures(tag, history, False)
            if len(features) == 0:
                continue
            probability = self.probability(tag, history, weights, features)
            for index in features:
                dictionary[index] += probability

    def calc_gradient(self, weights):
        """Calculate gradient vector over entire dataset"""
        timer = Timer("Gradient Calculation")
        empirical_counts = self.feature_factory.getEmpiricalCounts()
        expected_counts_dict = self.calcExpectedCountsDict(weights)
        expected_counts = self.calcExpectedCountsVector(expected_counts_dict)
        if self.regularizer == 1.0:
            regularization_counts = weights
        elif self.regularizer != 0.0:
            regularization_counts = self.regularizer * weights
        else:
            regularization_counts = 0.0
        total = regularization_counts + expected_counts - empirical_counts
        timer.stop()
        print("Average Gradient value:", np.mean(total))
        return total

    def calcExpectedCountsVector(self, dictionary):
        """Convert ExpectedCounts dictionary to numpy vector"""
        indexes = dictionary.keys()
        vector = np.zeros(self.feature_factory.getFeaturesVectorLength(), dtype=float)
        for index in indexes:
            vector[index] = dictionary.get(index, 0.0)
        return vector

    def fit(self, max_iter=100, tolerance=0.001, factr=1e12, save=True):
        """Train the model using L-BFGS-B"""
        timer = Timer("Training")
        weights, loss, result = fmin_l_bfgs_b(self.calc_loss, self.weights, self.calc_gradient, pgtol=tolerance, maxiter=max_iter, factr=factr)
        if result.get("warnflag", False) != 0:
            print("Warning - gradient didn't converge within", max_iter, "iterations")
        result['loss'] = loss
        print(result)
        self.train_results = result
        self.weights = weights
        timer.stop()
        if save:
            import os
            os.makedirs('./cache', exist_ok=True)
            with open(self.getTrainedWeightsCacheName(), 'wb') as cache:
                pickle.dump({'weights': self.weights, 'train_results': self.train_results}, cache)

    def predictSequential(self, data, cutoff):
        """Run predictions sequentially instead of using threads"""
        timer = Timer("Predicting " + str(data.getSentencesSize()) + " sentences")
        predictions = {}
        
        for i in range(data.getSentencesSize()):
            sentence = data.getSentenceByIndex(i)
            tags = data.getTagsByIndex(i)
            viterbi = ViterbiAlgorithm(i, sentence, tags, self, cutoff)
            viterbi.run()
            predictions[i] = viterbi.getBestTagSequence()
        
        timer.stop()
        return predictions

    def predict(self, data, cutoff=3):
        """Predict tags for entire dataset"""
        timer = Timer("Inference")
        self.predictions[data.file] = self.predictSequential(data, cutoff)
        timer.stop()

    def evaluate(self, data, verbose=False):
        """Evaluate model predictions vs truth"""
        assert data.getTagsSize() == len(self.predictions.get(data.file, [])), "Predictions and truth are not the same length!"
        timer = Timer("Evaluation")
        accuracies = []
        for i in range(data.getTagsSize()):
            truth = data.getTagsByIndex(i)
            prediction = self.predictions.get(data.file).get(i, False)
            accuracies.append(self.accuracy(truth, prediction, verbose))
        avg = np.mean(accuracies)
        minimum = np.min(accuracies)
        maximum = np.max(accuracies)
        med = np.median(accuracies)
        print("Results for", data.file)
        print("Total Average Accuracy:", avg)
        print("Minimal Accuracy:", minimum)
        print("Maximal Accuracy:", maximum)
        print("Median Accuracy:", med)
        self.confusionTable(data.file)
        self.confusionMatrix(data.file)
        timer.stop()
        return data.file, avg, minimum, maximum, med

    def accuracy(self, truth, predictions, verbose=False):
        """Calculate accuracy for a given sentence"""
        assert len(truth) == len(predictions), "Predictions and truth are not the same length!"
        correct = 0
        for i in range(len(truth)):
            key = truth[i]
            subkey = predictions[i]
            if truth[i] == predictions[i]:
                correct += 1
                self.correct_tags[key] += 1
            else:
                self.wrong_tags[key] += 1
                self.wrong_tag_pairs[(key, subkey)] += 1
                if self.wrong_tags_dicts.get(key, False) is False:
                    self.wrong_tags_dicts[key] = defaultdict(int)
                self.wrong_tags_dicts[key][subkey] += 1
                if verbose:
                    print("Mistake in index", i, "(truth, prediction): ", key, subkey)
        result = float(correct) / len(truth)
        if verbose:
            print("Accuracy:", result)
        return result

    def confusionMatrix(self, file, n=10):
        """Produce Confusion Matrix for top n wrong tags"""
        top_wrong_tags = sorted(self.wrong_tags, key=self.wrong_tags.get, reverse=True)[:n]
        header = top_wrong_tags
        rows = []
        for truth in top_wrong_tags:
            columns = [truth]
            for prediction in top_wrong_tags:
                if truth == prediction:
                    columns.append(self.correct_tags.get(truth))
                else:
                    columns.append(self.wrong_tag_pairs.get((truth, prediction)))
            rows.append(columns)
        print("Confusion Matrix for " + self.feature_factory.type + " model on " + file + " dataset")
        header.insert(0, "Truth \ Predicted")
        print(tabulate(rows, headers=header))

    def confusionTable(self, file, n=10):
        """Produce Confusion Table for top n wrong tags"""
        top_wrong_tags = sorted(self.wrong_tag_pairs, key=self.wrong_tag_pairs.get, reverse=True)[:n]
        header = ("Correct Tag", "Model's Tag", "Frequency")
        rows = []
        for truth, prediction in tuple(top_wrong_tags):
            freq = self.wrong_tag_pairs.get((truth, prediction))
            rows.append((truth, prediction, freq))
        print("Confusion Table for " + self.feature_factory.type + " model on " + file + " dataset")
        print(tabulate(rows, headers=header))

    def getTrainedWeightsCacheName(self):
        """Get cache file name according to model parameters"""
        prefix = "./cache/"
        parameters = "data-" + str(self.data.getSentencesSize()) + "_features-" + self.feature_factory.type +"_weightSize-"\
                     + str(self.feature_factory.getFeaturesVectorLength()) + "_cutoff-" + str(self.feature_factory.getCutoffParameter()) \
                     + "_regularizer-" + str(self.regularizer)
        suffix = "_trained_weights.pkl"
        return prefix + parameters + suffix

    def loadTrainedWeights(self, file):
        """Load pretrained weights from cache file"""
        with open(file, 'rb') as cache:
            trained = pickle.load(cache)
            weights = trained.get('weights')
        return weights

    def evaluate_with_sklearn(self, data, verbose=False):
        """Evaluate model using sklearn metrics"""
        assert data.getTagsSize() == len(self.predictions.get(data.file, [])), "Predictions and truth are not the same length!"
        
        # Flatten all true tags and predictions
        y_true = []
        y_pred = []
        
        for i in range(data.getTagsSize()):
            truth = data.getTagsByIndex(i)
            prediction = self.predictions.get(data.file).get(i, [])
            
            y_true.extend(truth)
            y_pred.extend(prediction)
        
        overall_accuracy = accuracy_score(y_true, y_pred)
        
        report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
        
        df_report = pd.DataFrame(report).transpose()
        
        # Print results
        print(f"\n=== Sklearn Evaluation Results for MEMM ===")
        print(f"Overall Accuracy: {overall_accuracy:.4f}")
        print("\nDetailed Classification Report:")
        print(df_report.round(4))
        
        # Create a summary table for main metrics
        main_tags = [tag for tag in df_report.index if tag not in ['accuracy', 'macro avg', 'weighted avg']]
        summary_data = []
        
        for tag in main_tags:
            summary_data.append({
                'Tag': tag,
                'Precision': df_report.loc[tag, 'precision'],
                'Recall': df_report.loc[tag, 'recall'],
                'F1-Score': df_report.loc[tag, 'f1-score'],
                'Support': int(df_report.loc[tag, 'support'])
            })
        
        summary_df = pd.DataFrame(summary_data)
        
        # Add overall metrics
        overall_metrics = pd.DataFrame([
            {
                'Tag': 'macro avg',
                'Precision': df_report.loc['macro avg', 'precision'],
                'Recall': df_report.loc['macro avg', 'recall'],
                'F1-Score': df_report.loc['macro avg', 'f1-score'],
                'Support': int(df_report.loc['macro avg', 'support'])
            },
            {
                'Tag': 'weighted avg',
                'Precision': df_report.loc['weighted avg', 'precision'],
                'Recall': df_report.loc['weighted avg', 'recall'],
                'F1-Score': df_report.loc['weighted avg', 'f1-score'],
                'Support': int(df_report.loc['weighted avg', 'support'])
            }
        ])
        
        final_summary = pd.concat([summary_df, overall_metrics], ignore_index=True)
        
        print("\n=== Summary Table ===")
        print(final_summary.round(4).to_string(index=False))
        
        # Show confusion matrix for top tags
        if verbose:
            from sklearn.metrics import confusion_matrix
            import seaborn as sns
            import matplotlib.pyplot as plt
            
            # Get top 10 most frequent tags
            tag_counts = pd.Series(y_true).value_counts()
            top_tags = tag_counts.head(10).index.tolist()
            
            y_true_filtered = []
            y_pred_filtered = []
            
            for true_tag, pred_tag in zip(y_true, y_pred):
                if true_tag in top_tags and pred_tag in top_tags:
                    y_true_filtered.append(true_tag)
                    y_pred_filtered.append(pred_tag)
            
            if len(y_true_filtered) > 0 and len(y_pred_filtered) > 0:
                cm = confusion_matrix(y_true_filtered, y_pred_filtered, labels=top_tags)
                
                plt.figure(figsize=(12, 10))
                sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                           xticklabels=top_tags, yticklabels=top_tags)
                plt.title('Confusion Matrix - Top 10 Tags')
                plt.ylabel('True Label')
                plt.xlabel('Predicted Label')
                plt.show()
            else:
                print("Not enough data for confusion matrix visualization")
        
        return {
            'overall_accuracy': overall_accuracy,
            'classification_report': report,
            'summary_df': final_summary,
            'y_true': y_true,
            'y_pred': y_pred
        }

class Tagger:
    
    def __init__(self, model_path, cutoff=3):
        self.cutoff = cutoff
        self.memm = self.load_complete_model(model_path)
        
        # Use global segmenter instead of loading new one
        self.segmenter = get_global_segmenter()
    
    @staticmethod
    def load_complete_model(filepath):
        """Load complete model from file"""
        with open(filepath, 'rb') as f:
            model_data = pickle.load(f)
        
        # Reconstruct the model
        model = MEMM(model_data['feature_factory'], 
                    regularizer=model_data['model_metadata']['regularizer'])
        
        # Set all the loaded components
        model.weights = model_data['weights']
        model.data = model_data['data_reader']
        model.feature_factory = model_data['feature_factory']
        model.train_results = model_data.get('train_results')
        
        return model
    
    def segment_text(self, text):
        """Always segment text using VnCoreNLP"""
        if self.segmenter is None:
            st.error("VnCoreNLP not available. Cannot process text.")
            return []
        
        try:
            segmented = self.segmenter.word_segment(text)
            return segmented
        except Exception as e:
            st.error(f"Segmentation error: {str(e)}")
            return []
    
    def tag_single_word(self, word):
        """Tag a single word using VnCoreNLP segmentation"""
        # Always use VnCoreNLP to segment the word (in case it's compound)
        segmented_sentences = self.segment_text(word)
        
        if not segmented_sentences:
            return 'Unknown'
        
        # Take the first segmented sentence
        segmented_word = segmented_sentences[0].strip()
        
        # Split into tokens
        tokens = segmented_word.split()
        
        if not tokens:
            return 'Unknown'
        
        # If it's just one token, add a period for context
        if len(tokens) == 1:
            sentence_tokens = [tokens[0], '.']
        else:
            sentence_tokens = tokens
        
        sentence_tuple = tuple(sentence_tokens)
        dummy_tags = ['O'] * len(sentence_tuple)
        
        # Run Viterbi algorithm
        viterbi = ViterbiAlgorithm(0, sentence_tuple, dummy_tags, self.memm, cutoff=self.cutoff)
        viterbi.run()
        
        # Return only the tag for the first word
        tags = list(viterbi.getBestTagSequence())
        return tags[0] if tags else 'Unknown'
    
    def tag_sentence(self, sentence):
        """Tag a sentence using VnCoreNLP segmentation"""
        # Always use VnCoreNLP to segment the sentence
        segmented_sentences = self.segment_text(sentence)
        
        if not segmented_sentences:
            return []
        
        # Process each segmented sentence
        all_results = []
        for seg_sentence in segmented_sentences:
            words = seg_sentence.split()
            if words:  # Only process non-empty sentences
                tags = self._tag_word_list(words)
                all_results.append({
                    'original_input': sentence,
                    'segmented_sentence': seg_sentence,
                    'words': words,
                    'tags': tags
                })
        
        return all_results
    
    def _tag_word_list(self, words):
        """Tag a list of words"""
        if not words:
            return []
            
        sentence_tuple = tuple(words)
        dummy_tags = ['O'] * len(sentence_tuple)
        
        # Run Viterbi algorithm
        viterbi = ViterbiAlgorithm(0, sentence_tuple, dummy_tags, self.memm, cutoff=self.cutoff)
        viterbi.run()
        
        return list(viterbi.getBestTagSequence())



import re
import math

def load_tagged_sentences(path):
    """Load sentences using CH tag, but only sentence-ending punctuation"""
    sentence = []
    sentence_ending_punctuation = {'.', '!', '?', '...', '..'} 
    
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue 
                
            try:
                word, tag = re.split(r'\s+', line)
                sentence.append((word, tag))
                
                if tag == 'CH' and word in sentence_ending_punctuation:
                    if sentence:
                        yield sentence
                        sentence = []
                        
            except ValueError:
                # Handle malformed lines
                continue
    
    # Yield final sentence if exists
    if sentence:
        yield sentence

def train_trigram_hmm(sentences_path, min_count=1):
    # Build vocabulary first
    word_counts = Counter()
    for sentence in load_tagged_sentences(sentences_path):
        for word, _ in sentence:
            word_counts[word] += 1
    
    vocab = {word for word, count in word_counts.items() if count >= min_count}
    vocab.update(['<UNK>', '<s>', '</s>'])
    
    # Train model
    emission_counts = defaultdict(Counter)
    trigram_counts = Counter()
    bigram_counts = Counter()
    unigram_counts = Counter()
    tagset = set(['<s>', '</s>'])

    for sentence in load_tagged_sentences(sentences_path):
        # Replace rare words with <UNK>
        processed_sentence = []
        for word, tag in sentence:
            word = word if word in vocab else '<UNK>'
            processed_sentence.append((word, tag))
            tagset.add(tag)

        words = ['<s>', '<s>'] + [w for w, t in processed_sentence] + ['</s>']
        tags = ['<s>', '<s>'] + [t for w, t in processed_sentence] + ['</s>']

        for word, tag in processed_sentence:
            emission_counts[tag][word] += 1

        for i in range(len(tags)):
            unigram_counts[tags[i]] += 1
        for i in range(len(tags)-1):
            bigram_counts[(tags[i], tags[i+1])] += 1
        for i in range(len(tags)-2):
            trigram_counts[(tags[i], tags[i+1], tags[i+2])] += 1

    return emission_counts, trigram_counts, bigram_counts, unigram_counts, tagset, vocab

class ProbabilityCache:
    def __init__(self, emission_counts, trigram_counts, bigram_counts, tagset, vocab):
        self.emission_counts = emission_counts
        self.trigram_counts = trigram_counts
        self.bigram_counts = bigram_counts
        self.tagset = tagset
        self.vocab = vocab
        self.tagset_size = len(tagset)
        self.vocab_size = len(vocab)
        
        # Pre-compute tag totals
        self.tag_totals = {}
        for tag in emission_counts:
            self.tag_totals[tag] = sum(emission_counts[tag].values())
    
    def trigram_prob(self, t1, t2, t3):
        numerator = self.trigram_counts[(t1, t2, t3)] + 1
        denominator = self.bigram_counts[(t1, t2)] + self.tagset_size
        return math.log(numerator / denominator)
    
    def emission_prob(self, tag, word):
        numerator = self.emission_counts[tag][word] + 1
        denominator = self.tag_totals.get(tag, 0) + self.vocab_size
        return math.log(numerator / denominator)

def viterbi_decode(sentence, prob_cache):
    n = len(sentence)
    tagset = prob_cache.tagset
    vocab = prob_cache.vocab
    
    V = [{} for _ in range(n + 1)]
    backpointer = [{} for _ in range(n + 1)]
    
    V[0][('<s>', '<s>')] = 0.0
    
    for i in range(n):
        word = sentence[i] if sentence[i] in vocab else '<UNK>'
        
        for (t1, t2) in V[i]:
            for t3 in tagset:
                if t3 in ['<s>', '</s>']:
                    continue
                    
                trans_prob = prob_cache.trigram_prob(t1, t2, t3)
                emis_prob = prob_cache.emission_prob(t3, word)
                score = V[i][(t1, t2)] + trans_prob + emis_prob
                
                key = (t2, t3)
                if key not in V[i+1] or score > V[i+1][key]:
                    V[i+1][key] = score
                    backpointer[i+1][key] = (t1, t2)
    
    if not V[n]:
        return ['<UNK>'] * n
    
    best_final = max(V[n], key=V[n].get)
    
    tags = []
    current = best_final
    for i in range(n, 0, -1):
        tags.append(current[1])
        if i > 1 and current in backpointer[i]:
            current = backpointer[i][current]
    
    return list(reversed(tags))

class HMM:
    """HMM model for Vietnamese POS tagging"""
    def __init__(self, prob_cache=None):
        self.prob_cache = prob_cache
        self.train_results = None
        self.predictions = {}
        
    def fit(self, sentences_path, min_count=1, save=True):
        """Train the HMM model"""
        timer = Timer("HMM Training")
        
        emission_counts, trigram_counts, bigram_counts, unigram_counts, tagset, vocab = train_trigram_hmm(sentences_path, min_count)
        
        self.prob_cache = ProbabilityCache(emission_counts, trigram_counts, bigram_counts, tagset, vocab)
        
        # Create a simple data reader compatible with MEMM interface
        self.data = HMMDataReader(tagset, vocab, emission_counts)
        
        self.train_results = {
            'tagset_size': len(tagset),
            'vocab_size': len(vocab),
            'emission_features': len(emission_counts),
            'trigram_features': len(trigram_counts)
        }
        
        timer.stop()
        
        if save:
            self.save_model()
    
    def predict(self, data, cutoff=None):
        """Predict tags for sentences"""
        timer = Timer("HMM Inference")
        predictions = {}
        
        for i in range(data.getSentencesSize()):
            sentence = data.getSentenceByIndex(i)
            predicted_tags = viterbi_decode(sentence, self.prob_cache)
            predictions[i] = tuple(predicted_tags)
        
        self.predictions[data.file] = predictions
        timer.stop()
    
    def tag_sentence_words(self, words):
        """Tag a list of words"""
        if not self.prob_cache:
            return ['<UNK>'] * len(words)
        
        return viterbi_decode(words, self.prob_cache)
    
    def save_model(self, filepath=None):
        """Save HMM model"""
        if filepath is None:
            filepath = "./cache/HMM_model.pkl"
        
        save_hmm_model(self.prob_cache, filepath, self.train_results)
    
    @staticmethod  # Add this decorator
    def load_model(model_path):
        """Load a trained HMM model from disk with validation"""
        try:
            with open(model_path, 'rb') as f:
                model_data = pickle.load(f)
            
            # Validate required fields
            required_fields = ['emission_counts', 'trigram_counts', 'bigram_counts', 
                              'tagset', 'vocab', 'tagset_size', 'vocab_size']
            
            for field in required_fields:
                if field not in model_data:
                    raise ValueError(f"Missing required field: {field}")
            
            # Convert back to defaultdict if needed
            emission_counts = defaultdict(Counter)
            for tag, word_counts in model_data['emission_counts'].items():
                emission_counts[tag] = Counter(word_counts)
            
            trigram_counts = Counter(model_data['trigram_counts'])
            bigram_counts = Counter(model_data['bigram_counts'])
            
            # Recreate ProbabilityCache with loaded data
            prob_cache = ProbabilityCache(
                emission_counts,
                trigram_counts,
                bigram_counts,
                model_data['tagset'],
                model_data['vocab']
            )
            
            # Create HMM instance
            hmm = HMM(prob_cache)
            hmm.train_results = model_data.get('train_results')
            
            # Create compatible data reader with error handling
            try:
                hmm.data = HMMDataReader(
                    prob_cache.tagset,
                    prob_cache.vocab,
                    prob_cache.emission_counts
                )
            except Exception as e:
                print(f"Warning: Could not create HMMDataReader: {e}")
                # Create minimal data reader as fallback
                class MinimalHMMData:
                    def __init__(self, tagset, vocab):
                        self.tagset = tagset - {'<s>', '</s>'}
                        self.vocab = vocab
                        self.file = 'hmm_data'
                    
                    def getTagSet(self):
                        return tuple(self.tagset)
                    
                    def getTagSetSize(self):
                        return len(self.tagset)
                    
                    def getWordDictSize(self):
                        return len(self.vocab)
                    
                    def getSentencesSize(self):
                        return 0
                    
                    def getTagsSize(self):
                        return 0
                    
                    def getSentenceByIndex(self, i):
                        return []
                    
                    def getTagsByIndex(self, i):
                        return []
                
                hmm.data = MinimalHMMData(prob_cache.tagset, prob_cache.vocab)
            
            print(f"HMM model loaded from {model_path}")
            print(f"Tagset size: {len(prob_cache.tagset)}")
            print(f"Vocab size: {len(prob_cache.vocab)}")
            
            return hmm
            
        except FileNotFoundError:
            print(f"Model file not found: {model_path}")
            raise
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise e

# Also add the standalone save/load functions for compatibility:

def save_hmm_model(prob_cache, model_path, train_results=None):
    """Save the trained HMM model to disk with better error handling"""
    try:
        model_data = {
            'emission_counts': dict(prob_cache.emission_counts),  # Convert defaultdict to dict
            'trigram_counts': dict(prob_cache.trigram_counts),
            'bigram_counts': dict(prob_cache.bigram_counts),
            'tagset': prob_cache.tagset,
            'vocab': prob_cache.vocab,
            'tag_totals': prob_cache.tag_totals,
            'tagset_size': prob_cache.tagset_size,
            'vocab_size': prob_cache.vocab_size,
            'train_results': train_results,
            'model_type': 'HMM',
            'version': '1.0'  # Add version for compatibility checking
        }
        
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        
        with open(model_path, 'wb') as f:
            pickle.dump(model_data, f)
        
        print(f"HMM model saved to {model_path}")
        print(f"Tagset size: {len(prob_cache.tagset)}")
        print(f"Vocab size: {len(prob_cache.vocab)}")
        
    except Exception as e:
        print(f"Error saving model: {str(e)}")
        raise e

def load_hmm_model(model_path):
    """Load a trained HMM model from disk with validation"""
    try:
        with open(model_path, 'rb') as f:
            model_data = pickle.load(f)
        
        # Validate required fields
        required_fields = ['emission_counts', 'trigram_counts', 'bigram_counts', 
                          'tagset', 'vocab', 'tagset_size', 'vocab_size']
        
        for field in required_fields:
            if field not in model_data:
                raise ValueError(f"Missing required field: {field}")
        
        # Convert back to defaultdict if needed
        emission_counts = defaultdict(Counter)
        for tag, word_counts in model_data['emission_counts'].items():
            emission_counts[tag] = Counter(word_counts)
        
        trigram_counts = Counter(model_data['trigram_counts'])
        bigram_counts = Counter(model_data['bigram_counts'])
        
        # Recreate ProbabilityCache with loaded data
        prob_cache = ProbabilityCache(
            emission_counts,
            trigram_counts,
            bigram_counts,
            model_data['tagset'],
            model_data['vocab']
        )
        
        print(f"HMM model loaded from {model_path}")
        print(f"Tagset size: {len(prob_cache.tagset)}")
        print(f"Vocab size: {len(prob_cache.vocab)}")
        
        return prob_cache, model_data.get('train_results')
        
    except FileNotFoundError:
        print(f"Model file not found: {model_path}")
        raise
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise e

class HMMDataReader:
    """Data reader compatible with MEMM interface for HMM"""
    def __init__(self, tagset, vocab, emission_counts):
        self.tagset = tagset
        self.vocab = vocab
        self.emission_counts = emission_counts
        self.file = "hmm_data"
        self.sentences = []
        self.tags = []
    
    def getTagSet(self):
        return tuple(self.tagset - {'<s>', '</s>'})
    
    def getTagSetSize(self):
        return len(self.tagset - {'<s>', '</s>'})
    
    def getWordDictSize(self):
        return len(self.vocab)
    
    def getSentencesSize(self):
        return len(self.sentences)
    
    def getTagsSize(self):
        return len(self.tags)
    
    def getSentenceByIndex(self, index):
        return self.sentences[index]
    
    def getTagsByIndex(self, index):
        return self.tags[index]
    
    def set_test_data(self, sentences, tags):
        """Set test data for evaluation"""
        self.sentences = sentences
        self.tags = tags

class UnifiedTagger:
    """Unified tagger that can work with both MEMM and HMM models"""
    
    def __init__(self, model_path, model_type="MEMM", cutoff=3):
        self.cutoff = cutoff
        self.model_type = model_type
        
        if model_type == "MEMM":
            self.model = self.load_memm_model(model_path)
        elif model_type == "HMM":
            self.model = HMM.load_model(model_path)
        else:
            raise ValueError("model_type must be 'MEMM' or 'HMM'")
        
        # Use global segmenter instead of loading new one
        self.segmenter = get_global_segmenter()
    
    @staticmethod
    def load_memm_model(filepath):
        """Load MEMM model from file"""
        with open(filepath, 'rb') as f:
            model_data = pickle.load(f)
        
        # Reconstruct the model
        model = MEMM(model_data['feature_factory'], 
                    regularizer=model_data['model_metadata']['regularizer'])
        
        # Set all the loaded components
        model.weights = model_data['weights']
        model.data = model_data['data_reader']
        model.feature_factory = model_data['feature_factory']
        model.train_results = model_data.get('train_results')
        
        return model
    
    def segment_text(self, text):
        """Always segment text using VnCoreNLP"""
        if self.segmenter is None:
            st.error("VnCoreNLP not available. Cannot process text.")
            return []
        
        try:
            segmented = self.segmenter.word_segment(text)
            return segmented
        except Exception as e:
            st.error(f"Segmentation error: {str(e)}")
            return []
    
    def tag_single_word(self, word):
        """Tag a single word using the selected model"""
        # Always use VnCoreNLP to segment the word (in case it's compound)
        segmented_sentences = self.segment_text(word)
        
        if not segmented_sentences:
            return 'Unknown'
        
        # Take the first segmented sentence
        segmented_word = segmented_sentences[0].strip()
        
        # Split into tokens (this preserves underscores in compound words)
        tokens = segmented_word.split()
        
        if not tokens:
            return 'Unknown'
        
        # If it's just one token, add a period for context
        if len(tokens) == 1:
            sentence_tokens = [tokens[0], '.']
        else:
            sentence_tokens = tokens
        
        if self.model_type == "MEMM":
            sentence_tuple = tuple(sentence_tokens)
            dummy_tags = ['O'] * len(sentence_tuple)
            
            # Run Viterbi algorithm
            viterbi = ViterbiAlgorithm(0, sentence_tuple, dummy_tags, self.model, cutoff=self.cutoff)
            viterbi.run()
            
            # Return only the tag for the first word
            tags = list(viterbi.getBestTagSequence())
            return tags[0] if tags else 'Unknown'
        
        elif self.model_type == "HMM":
            # Handle unknown words for HMM
            processed_tokens = []
            for token in sentence_tokens:
                if token in self.model.prob_cache.vocab:
                    processed_tokens.append(token)
                else:
                    processed_tokens.append('<UNK>')
            
            tags = self.model.tag_sentence_words(processed_tokens)
            return tags[0] if tags else 'Unknown'
    
    def tag_sentence(self, sentence):
        """Tag a sentence using the selected model"""
        # Always use VnCoreNLP to segment the sentence
        segmented_sentences = self.segment_text(sentence)
        
        if not segmented_sentences:
            return []
        
        # Process each segmented sentence
        all_results = []
        for segmented_sentence in segmented_sentences:
            # Split the segmented sentence into words (preserving underscores)
            words = segmented_sentence.split()
            
            if words:  # Only process non-empty sentences
                tags = self._tag_word_list(words)
                
                all_results.append({
                    'original_input': sentence,
                    'segmented_sentence': segmented_sentence,  # Keep original VnCoreNLP format!
                    'words': words,
                    'tags': tags
                })
        
        return all_results
    
    def _tag_word_list(self, words):
        """Tag a list of words using the selected model"""
        if not words:
            return []
        
        if self.model_type == "MEMM":
            sentence_tuple = tuple(words)
            dummy_tags = ['O'] * len(sentence_tuple)
            
            # Run Viterbi algorithm
            viterbi = ViterbiAlgorithm(0, sentence_tuple, dummy_tags, self.model, cutoff=self.cutoff)
            viterbi.run()
            
            return list(viterbi.getBestTagSequence())
        
        elif self.model_type == "HMM":
            # Handle unknown words for HMM
            processed_words = []
            for word in words:
                if word in self.model.prob_cache.vocab:
                    processed_words.append(word)
                else:
                    processed_words.append('<UNK>')
            
            return self.model.tag_sentence_words(processed_words)

def main():
    st.set_page_config(
        page_title="Bộ Gán Nhãn Từ Loại Tiếng Việt",
        page_icon="🏷️",
        layout="wide"
    )
    
    st.title("🏷️ Bộ Gán Nhãn Từ Loại Tiếng Việt")
    st.markdown("*Từ Nhóm KKK*")
    st.markdown("---")
    
    # Initialize VnCoreNLP once at the beginning
    with st.spinner("Đang khởi tạo VnCoreNLP..."):
        segmenter = get_global_segmenter()
        if segmenter is None:
            st.error("❌ Không thể tải VnCoreNLP! Vui lòng kiểm tra cài đặt.")
            return
        else:
            st.success("✅ VnCoreNLP đã được tải thành công!")
    
    # Thanh bên cho việc chọn mô hình
    st.sidebar.header("Cấu Hình Mô Hình")
    
    # Model type selection
    model_type = st.sidebar.selectbox(
        "Chọn Loại Mô Hình:",
        options=["MEMM", "HMM"],
        help="Chọn loại mô hình: MEMM (Maximum Entropy Markov Model) hoặc HMM (Hidden Markov Model)"
    )
    
    # Lựa chọn Tập dữ liệu/Mô hình
    if model_type == "MEMM":
        dataset_options = {
            "VNDT": "/kaggle/input/nlp_weights/pytorch/default/6/MEMM_VNDT.pkl",
            "Dataset tự tạo": "/kaggle/input/nlp_weights/pytorch/default/6/MEMM_custom.pkl"
        }
    else:  # HMM
        dataset_options = {
            "VNDT": "/kaggle/input/nlp_weights/pytorch/default/6/HMM_VNDT.pkl",
            "Dataset tự tạo": "/kaggle/input/nlp_weights/pytorch/default/6/HMM_custom.pkl"
        }
    
    selected_dataset = st.sidebar.selectbox(
        "Chọn Tập Dữ Liệu/Mô Hình:",
        options=list(dataset_options.keys()),
        help="Chọn tập dữ liệu được sử dụng để huấn luyện mô hình"
    )
    
    # Tham số ngưỡng (only for MEMM)
    if model_type == "MEMM":
        cutoff = st.sidebar.slider(
            "Ngưỡng Gán Nhãn:",
            min_value=1,
            max_value=5,
            value=1,
            help="Số lượng tối đa các nhãn có thể được xem xét cho mỗi từ"
        )
    else:
        cutoff = 3  # Default for HMM
    
    # Tải mô hình
    model_path = dataset_options[selected_dataset]
    
    try:
        # Use proper caching for the tagger loading
        @st.cache_resource
        def load_tagger_cached(model_path, model_type, cutoff):
            if model_type == "MEMM":
                return Tagger(model_path, cutoff=cutoff)
            else:
                return UnifiedTagger(model_path, model_type=model_type, cutoff=cutoff)
        
        with st.spinner(f"Đang tải mô hình {model_type}..."):
            tagger = load_tagger_cached(model_path, model_type, cutoff)
        
        st.sidebar.success(f"✅ Mô hình {model_type} được tải thành công!")
        
        # Get available tags from loaded model
        if model_type == "MEMM":
            available_tags = list(tagger.memm.data.getTagSet())
            model_info = {
                'word_dict_size': tagger.memm.data.getWordDictSize(),
                'feature_vector_length': tagger.memm.feature_factory.getFeaturesVectorLength(),
                'additional_info': None
            }
        else:
            available_tags = list(tagger.model.data.getTagSet())
            model_info = {
                'word_dict_size': tagger.model.data.getWordDictSize(),
                'feature_vector_length': None,
                'additional_info': {
                    'emission_features': len(tagger.model.prob_cache.emission_counts),
                    'trigram_features': len(tagger.model.prob_cache.trigram_counts)
                }
            }
        
        # Display model information
        with st.sidebar.expander("Thông Tin Mô Hình"):
            st.write(f"**Loại Mô Hình:** {model_type}")
            st.write(f"**Tập Dữ Liệu:** {selected_dataset}")
            st.write(f"**Kích Thước Từ Vựng:** {model_info['word_dict_size']:,}")
            st.write(f"**Kích Thước Tập Nhãn:** {len(available_tags)}")
            
            if model_type == "MEMM":
                st.write(f"**Độ Dài Vector Đặc Trưng:** {model_info['feature_vector_length']:,}")
            else:
                st.write(f"**Số Đặc Trưng Emission:** {model_info['additional_info']['emission_features']:,}")
                st.write(f"**Số Đặc Trưng Trigram:** {model_info['additional_info']['trigram_features']:,}")
            
            st.write(f"**Phân Đoạn:** Luôn được bật (VnCoreNLP)")
            
            # Display available tags
            st.write(f"**Các Nhãn Có Sẵn:** {', '.join(sorted(available_tags))}")
    
    except Exception as e:
        st.sidebar.error(f"❌ Lỗi khi tải mô hình: {str(e)}")
        st.error("Vui lòng kiểm tra đường dẫn tệp mô hình và thử lại.")
        return
    
    # Ý nghĩa nhãn động dựa trên các nhãn có sẵn
    def get_tag_meaning(tag, available_tags):
        """Lấy ý nghĩa nhãn, chỉ khi nhãn tồn tại trong mô hình"""
        if tag not in available_tags:
            return None
            
        tag_meanings = {
            'N': 'Danh từ',
            'V': 'Động từ', 
            'A': 'Tính từ',
            'R': 'Phó từ',
            'E': 'Giới từ',
            'M': 'Số từ',
            'L': 'Định từ',
            'P': 'Đại từ',
            'C': 'Liên từ',
            'I': 'Thán từ',
            'Np': 'Danh từ riêng',
            'CH': 'Dấu câu',
            'Nc': 'Danh từ chung',
            'Nu': 'Danh từ đơn vị',
            'Ny': 'Danh từ viết tắt',
            'Cc': 'Liên từ đẳng lập',
            'T':  'Trợ từ',
            'X': 'Không xác định',
            'Y': 'Từ viết tắt',
            'Vb': 'Động từ cơ bản',
            'NP': 'Cụm danh từ riêng',
            'Nb': 'Danh từ mượn',
            'S': 'Dấu hiệu câu',
            'WHNP': 'Cụm danh từ nghi vấn'
        }
        
        return tag_meanings.get(tag, f"Nhãn: {tag}")
    
    def validate_single_word(segmented_text):
        """Xác thực rằng văn bản được phân đoạn là một từ duy nhất (không có khoảng trắng)"""
        if ' ' in segmented_text.strip():
            return False, "Kết quả phân đoạn chứa nhiều từ. Vui lòng nhập một từ hoặc cụm từ duy nhất."
        return True, None
    
    # Khu vực nội dung chính
    col1, col2 = st.columns([1, 1])
    
    with col1:
        st.header("🔤 Gán Nhãn Một Từ")
        
        word_input = st.text_input(
            "Nhập một từ tiếng Việt:",
            placeholder="ví dụ: nhà, đẹp, học, sinh",
            help="Nhập một từ tiếng Việt duy nhất. Sau khi phân đoạn, nó phải là một từ (có hoặc không có dấu gạch dưới)."
        )
        
        if st.button("Gán Nhãn Từ", key="word_btn"):
            if word_input.strip():
                try:
                    with st.spinner("Đang phân đoạn và gán nhãn..."):
                        # Hiển thị kết quả phân đoạn trước
                        segmented = tagger.segment_text(word_input.strip())
                        
                        if segmented:
                            segmented_text = segmented[0].strip()  # Lấy câu phân đoạn đầu tiên
                            st.info(f"**Đã Phân Đoạn:** {segmented_text}")
                            
                            # Xác thực rằng đó là một từ duy nhất
                            is_valid, error_msg = validate_single_word(segmented_text)
                            
                            if not is_valid:
                                st.error(error_msg)
                                st.warning("Vui lòng thử nhập một từ duy nhất thay vì một cụm từ.")
                                return
                            
                            # Lấy nhãn bằng cách sử dụng văn bản đã phân đoạn
                            tag = tagger.tag_single_word(segmented_text)
                            
                            st.success(f"**Đầu Vào Ban Đầu:** {word_input}")
                            st.success(f"**Từ Đã Phân Đoạn:** {segmented_text}")
                            st.success(f"**Nhãn:** {tag}")
                            
                            # Hiển thị ý nghĩa nhãn nếu có
                            meaning = get_tag_meaning(tag, available_tags)
                            if meaning:
                                st.info(f"**Ý Nghĩa:** {meaning}")
                            else:
                                st.info(f"**Nhãn:** {tag} (ý nghĩa không có sẵn)")
                        else:
                            st.error("Không thể phân đoạn đầu vào.")
                        
                except Exception as e:
                    st.error(f"Lỗi khi xử lý đầu vào: {str(e)}")
            else:
                st.warning("Vui lòng nhập một từ!")
    
    with col2:
        st.header("📝 Gán Nhãn Văn Bản")
        
        sentence_input = st.text_area(
            "Nhập văn bản tiếng Việt:",
            placeholder="ví dụ: Ông Nguyễn Khắc Chúc đang làm việc tại Đại học Quốc gia Hà Nội. Bà Lan cũng làm việc tại đây.",
            height=120,
            help="Nhập văn bản tiếng Việt. Nó sẽ được tự động phân đoạn bằng VnCoreNLP."
        )
        
        if st.button("Gán Nhãn Văn Bản", key="sentence_btn"):
            if sentence_input.strip():
                try:
                    with st.spinner("Đang phân đoạn và gán nhãn văn bản..."):
                        # Tag the sentence directly using the appropriate tagger method
                        sentence_results = tagger.tag_sentence(sentence_input.strip())
                    
                    if sentence_results:
                        st.success("**Kết Quả Xử Lý:**")
                        
                        for i, result in enumerate(sentence_results):
                            words = result['words']
                            tags = result['tags']
                            segmented_sentence = result['segmented_sentence']
                            
                            if len(sentence_results) > 1:
                                st.subheader(f"Câu {i+1}:")
                            
                            st.write(f"**Ban Đầu:** {sentence_input}")
                            st.write(f"**Đã Phân Đoạn:** {segmented_sentence}")
                            
                            if len(words) == len(tags):
                                # Tạo bảng hiển thị đẹp mắt
                                result_data = []
                                for j, (word, tag) in enumerate(zip(words, tags)):
                                    meaning = get_tag_meaning(tag, available_tags)
                                    result_data.append({
                                        "Vị Trí": j + 1,
                                        "Từ": word,
                                        "Nhãn": tag,
                                        "Ý Nghĩa": meaning if meaning else f"Nhãn: {tag}"
                                    })
                                
                                st.table(result_data)
                                
                                # Hiển thị các nhãn dưới dạng huy hiệu màu
                                st.markdown("**Biểu Diễn Trực Quan:**")
                                
                                # Gán màu động cho các nhãn có sẵn
                                colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57', 
                                         '#FF9FF3', '#54A0FF', '#5F27CD', '#00D2D3', '#FF9F43', 
                                         '#FF3838', '#2F3542', '#FF6348', '#7F8C8D', '#9B59B6',
                                         '#E67E22', '#F39C12', '#D35400', '#C0392B', '#8E44AD']
                                
                                tag_colors = {}
                                for idx, tag in enumerate(sorted(available_tags)):
                                    tag_colors[tag] = colors[idx % len(colors)]
                                
                                html_output = ""
                                for word, tag in zip(words, tags):
                                    color = tag_colors.get(tag, '#95A5A6')
                                    html_output += f'<span style="background-color: {color}; padding: 4px 8px; margin: 2px; border-radius: 4px; color: white; font-weight: bold;">{word} <sub>{tag}</sub></span> '
                                
                                st.markdown(html_output, unsafe_allow_html=True)
                                
                                if len(sentence_results) > 1:
                                    st.markdown("---")
                            else:
                                st.error("Số lượng từ và nhãn không khớp!")
                    else:
                        st.error("Không thể xử lý văn bản.")
                        
                except Exception as e:
                    st.error(f"Lỗi khi xử lý văn bản: {str(e)}")
            else:
                st.warning("Vui lòng nhập văn bản!")
    
    # Phần ví dụ
    st.markdown("---")
    st.header("📚 Ví Dụ")
    
    col3, col4 = st.columns([1, 1])
    
    with col3:
        st.subheader("Ví Dụ Gán Nhãn Một Từ")
        example_words = [
            "nhà", 
            "đẹp", 
            "học",
            "cơm",
            "đi",
            "100"
        ]
        
        for word in example_words:
            if st.button(f"Gán Nhãn: {word}", key=f"example_word_{word}"):
                try:
                    with st.spinner(f"Đang xử lý {word}..."):
                        # Hiển thị phân đoạn
                        segmented = tagger.segment_text(word)
                        if segmented:
                            segmented_text = segmented[0].strip()
                            st.write(f"**Ban Đầu:** {word}")
                            st.write(f"**Đã Phân Đoạn:** {segmented_text}")
                            
                            # Xác thực từ đơn
                            is_valid, error_msg = validate_single_word(segmented_text)
                            if is_valid:
                                # Hiển thị nhãn bằng văn bản đã phân đoạn
                                tag = tagger.tag_single_word(segmented_text)
                                meaning = get_tag_meaning(tag, available_tags)
                                st.write(f"**Nhãn:** {tag}")
                                if meaning:
                                    st.write(f"**Ý Nghĩa:** {meaning}")
                            else:
                                st.write(f"**Lỗi:** {error_msg}")
                        else:
                            st.write(f"Không thể phân đoạn: {word}")
                except Exception as e:
                    st.write(f"Lỗi khi xử lý {word}: {str(e)}")
    
    with col4:
        st.subheader("Ví Dụ Gán Nhãn Văn Bản")
        example_texts = [
            "Tôi đang học tiếng Việt.",
            "Ông Nguyễn Khắc Chúc đang làm việc tại Đại học Quốc gia Hà Nội.",
            "Bà Lan, vợ ông Chúc, cũng làm việc tại đây.",
            "Sinh viên làm bài tập ở thư viện."
        ]
        
        for i, text in enumerate(example_texts):
            if st.button(f"Xử Lý: {text[:30]}...", key=f"example_text_{i}"):
                try:
                    with st.spinner(f"Đang xử lý văn bản {i+1}..."):
                        # Use the tag_sentence method directly
                        results = tagger.tag_sentence(text)
                        
                        for result in results:
                            words = result['words']
                            tags = result['tags']
                            
                            st.write(f"**Ban Đầu:** {text}")
                            st.write(f"**Đã Phân Đoạn:** {result['segmented_sentence']}")
                            
                            result_str = " ".join([f"{w}({t})" for w, t in zip(words, tags)])
                            st.write(f"**Đã Gán Nhãn:** {result_str}")
                            st.write("---")
                except Exception as e:
                    st.write(f"Lỗi khi xử lý văn bản: {str(e)}")
    
    # Phần thông tin
    st.markdown("---")
    st.header("ℹ️ Thông Tin")
    
    with st.expander("Về Gán Nhãn Một Từ"):
        st.markdown("""
        **Quy Tắc Nhập Từ:**
        - Đầu vào phải là một từ hoặc cụm từ duy nhất
        - Sau khi phân đoạn bằng VnCoreNLP, kết quả phải là một token
        - Ví dụ hợp lệ: "nhà" → "nhà", "học sinh" → "học_sinh"
        - Không hợp lệ: các đầu vào phân đoạn thành nhiều từ cách nhau bằng khoảng trắng
        
        **Ví Dụ:**
        - ✅ "nhã" → "nhà" (từ đơn)
        - ✅ "học sinh" → "học_sinh" (từ ghép có dấu gạch dưới)
        - ❌ "tôi đi học" → "tôi đi học" (nhiều từ)
        """)
    
    with st.expander("Về Tích Hợp VnCoreNLP"):
        st.markdown("""
        **Phân Đoạn Tự Động:**
        - Tất cả đầu vào được tự động phân đoạn bằng VnCoreNLP
        - Các từ ghép được nối đúng cách bằng dấu gạch dưới
        - Nhiều câu được xử lý riêng biệt
        
        **Ví Dụ Phân Đoạn:**
        - "học sinh" → "học_sinh"
        - "Nguyễn Khắc Chúc" → "Nguyễn_Khắc_Chúc"
        - "làm việc" → "làm_việc"
        - "Đại học Quốc gia" → "Đại_học Quốc_gia"
        """)
    
    with st.expander("Ý Nghĩa Nhãn Từ Loại"):
        # Hiển thị ý nghĩa nhãn động
        st.markdown("**Các Nhãn Có Sẵn Trong Mô Hình Hiện Tại:**")
        
        meanings_text = ""
        for tag in sorted(available_tags):
            meaning = get_tag_meaning(tag, available_tags)
            meanings_text += f"- **{tag}**: {meaning}\n"
        
        st.markdown(meanings_text)
    
    # Chân trang
    st.markdown("---")
    st.markdown(
        """
        <div style='text-align: center'>
            <p>Bộ Gán Nhãn Từ Loại Tiếng Việt với Phân Đoạn Tự Động VnCoreNLP | Được xây dựng bằng Streamlit</p>
        </div>
        """, 
        unsafe_allow_html=True
    )

if __name__ == "__main__":
    main()

Overwriting app_temp.py


In [10]:
!wget -q -O - ipv4.icanhazip.com

34.23.96.165


In [None]:
!streamlit run app_temp.py & npx localtunnel --port 8501

[1G[0K⠙
Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[1G[0K⠹[1G[0K⠸[1G[0K⠼[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8502[0m
[34m  Network URL: [0m[1mhttp://172.19.2.2:8502[0m
[34m  External URL: [0m[1mhttp://34.23.96.165:8502[0m
[0m
[1G[0K⠴[1G[0K⠦[1G[0Kyour url is: https://cute-dragons-beg.loca.lt
