### Set up libraries

In [None]:
# The version specifier is quoted: unquoted, the shell treats ">" as an output
# redirect, writes pip's output to a file named "=2.10" and ignores the constraint.
!pip install "tensorflow>=2.10"
!pip install numpy
!pip install pandas
!pip install scikit-learn
!pip install gensim
!pip install wordninja
!pip install emoji
!pip install datasets

In [None]:
# Extra packages: ftfy and nlpaug are imported by the model cell below;
# imbalanced-learn is installed here but not imported in the visible code.
!pip install ftfy nlpaug imbalanced-learn

In [None]:
# Mount Google Drive at /content/drive; the default data, embedding and
# checkpoint paths used by main() all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

### Model

In [None]:
import os
import re
import pickle
import random
import argparse
import logging
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Dropout, Bidirectional
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.mixed_precision import set_global_policy
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix, precision_recall_fscore_support
import gensim
import string
import sys
try:
    import emoji
except ModuleNotFoundError:
    print("Installing emoji...")
    os.system("pip install emoji")
    import emoji
try:
    import wordninja
except ModuleNotFoundError:
    print("Installing wordninja...")
    os.system("pip install wordninja")
    import wordninja
try:
    import nlpaug.augmenter.word as naw
except ModuleNotFoundError:
    print("Installing nlpaug...")
    os.system("pip install nlpaug")
    import nlpaug.augmenter.word as naw
try:
    import gensim
except ModuleNotFoundError:
    print("Installing gensim...")
    os.system("pip install gensim")
    import gensim
try:
    import ftfy
except ModuleNotFoundError:
    print("Installing ftfy...")
    os.system("pip install ftfy")
    import ftfy

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define the mappings used for preprocessing
# Maps English contractions to their expanded forms. clean_text applies every
# entry in insertion order with a plain substring replace, so matching is
# case-sensitive with respect to the keys as written here.
contraction_mapping = {
    "ain't": "is not", "aren't": "are not", "can't": "cannot", "'cause": "because",
    "could've": "could have", "couldn't": "could not", "didn't": "did not",
    "doesn't": "does not", "don't": "do not", "hadn't": "had not",
    "hasn't": "has not", "haven't": "have not", "he'd": "he would",
    "he'll": "he will", "he's": "he is", "how'd": "how did",
    "how'd'y": "how do you", "how'll": "how will", "how's": "how is",
    "I'd": "I would", "I'd've": "I would have", "I'll": "I will",
    "I'll've": "I will have", "I'm": "I am", "I've": "I have",
    "i'd": "i would", "i'd've": "i would have", "i'll": "i will",
    "i'll've": "i will have", "i'm": "i am", "i've": "i have",
    "isn't": "is not", "it'd": "it would", "it'd've": "it would have",
    "it'll": "it will", "it'll've": "it will have", "it's": "it is",
    "let's": "let us", "ma'am": "madam", "mayn't": "may not",
    "might've": "might have", "mightn't": "might not",
    "mightn't've": "might not have", "must've": "must have",
    "mustn't": "must not", "mustn't've": "must not have",
    "needn't": "need not", "needn't've": "need not have",
    "o'clock": "of the clock", "oughtn't": "ought not",
    "oughtn't've": "ought not have", "shan't": "shall not",
    "sha'n't": "shall not", "shan't've": "shall not have",
    "she'd": "she would", "she'd've": "she would have",
    "she'll": "she will", "she'll've": "she will have", "she's": "she is",
    "should've": "should have", "shouldn't": "should not",
    "shouldn't've": "should not have", "so've": "so have", "so's": "so as",
    "this's": "this is", "that'd": "that would", "that'd've": "that would have",
    "that's": "that is", "there'd": "there would", "there'd've": "there would have",
    "there's": "there is", "here's": "here is", "they'd": "they would",
    "they'd've": "they would have", "they'll": "they will",
    "they'll've": "they will have", "they're": "they are",
    "they've": "they have", "to've": "to have", "wasn't": "was not",
    "we'd": "we would", "we'd've": "we would have", "we'll": "we will",
    "we'll've": "we will have", "we're": "we are", "we've": "we have",
    "weren't": "were not", "what'll": "what will", "what'll've": "what will have",
    "what're": "what are", "what's": "what is", "what've": "what have",
    "when's": "when is", "when've": "when have", "where'd": "where did",
    "where's": "where is", "where've": "where have", "who'll": "who will",
    "who'll've": "who will have", "who's": "who is", "who've": "who have",
    "why's": "why is", "why've": "why have", "will've": "will have",
    "won't": "will not", "won't've": "will not have", "would've": "would have",
    "wouldn't": "would not", "wouldn't've": "would not have", "y'all": "you all",
    "y'all'd": "you all would", "y'all'd've": "you all would have",
    "y'all're": "you all are", "y'all've": "you all have", "you'd": "you would",
    "you'd've": "you would have", "you'll": "you will", "you'll've": "you will have",
    "you're": "you are", "you've": "you have", "u.s": "america", "e.g": "for example"
}

# Maps special/typographic characters to ASCII replacements; clean_text applies
# each entry with a plain substring replace.
# NOTE(review): many keys below look like mis-decoded UTF-8 (multi-character
# sequences rather than single symbols) — confirm the file's encoding, since
# such keys would only match that literal character sequence.
# NOTE(review): one key appears twice in this literal (end of the second row
# and start of the third); both map to '"', so the duplicate is harmless.
punct_mapping = {
    "‚Äò": "'", "‚Çπ": "e", "¬¥": "'", "¬∞": "", "‚Ç¨": "e", "‚Ñ¢": "tm", "‚àö": " sqrt ",
    "√ó": "x", "¬≤": "2", "‚Äî": "-", "‚Äì": "-", "‚Äô": "'", "_": "-", "`": "'", '‚Äú': '"',
    '‚Äù': '"', '‚Äú': '"', "¬£": "e", '‚àû': 'infinity', 'Œ∏': 'theta', '√∑': '/', 'Œ±': 'alpha',
    '‚Ä¢': '.', '√†': 'a', '‚àí': '-', 'Œ≤': 'beta', '‚àÖ': '', '¬≥': '3', 'œÄ': 'pi'
}

# Spelling normalisation table (British -> American spellings, common typos,
# fused words). clean_text looks up whole whitespace-separated tokens here.
# NOTE(review): some keys contain uppercase letters ('Qoura', 'doI', ...) or a
# trailing space ('youtu '), and 'pissed' maps to itself — verify these entries
# behave as intended for the token form that is looked up.
mispell_dict = {
    'colour': 'color', 'centre': 'center', 'favourite': 'favorite', 'travelling': 'traveling',
    'counselling': 'counseling', 'theatre': 'theater', 'cancelled': "canceled", 'labour': 'labor',
    'organisation': "organization", 'wwii': 'world war 2', 'citicise': 'criticize', 'youtu ': 'youtube ',
    'Qoura': 'Quora', 'sallary': 'salary', 'Whta': 'What', 'narcisist': 'narcissist', 'howdo': 'how do',
    'whatare': 'what are', 'howcan': "how can", 'howmuch': 'how much', 'howmany': 'how many',
    'whydo': 'why do', 'doI': 'do I', 'theBest': 'the best', 'howdoes': 'how does',
    'mastrubation': 'masturbation', 'mastrubate': 'masturbate', 'mastrubating': 'masturbating',
    'pennis': 'penis', 'Etherium': 'Ethereum', 'narcissit': 'narcissist', 'bigdata': 'big data',
    '2k17': '2017', '2k18': '2018', 'qouta': 'quota', 'exboyfriend': 'ex boyfriend',
    'airhostess': 'air hostess', 'whst': 'what', 'watsapp': 'whatsapp',
    'demonitisation': 'demonetization', 'demonitization': 'demonetization',
    'demonetisation': "demonetization", 'pissed': 'pissed'
}

# Characters that clean_text replaces with a space: ASCII punctuation plus a few
# extra typographic marks, minus '#', '!' and '?', which are deliberately kept.
# The list is sorted so the generated character class is deterministic.
# NOTE(review): the multi-character entries in the extra set look mis-decoded;
# set(...) | {...} keeps them as whole strings, so each of their characters ends
# up in the regex character class — confirm against the file's real encoding.
punct_chars = list((set(string.punctuation) | {
    "‚Äô", "‚Äò", "‚Äì", "‚Äî", "~", "|", "‚Äú", "‚Äù", "‚Ä¶", "'", "`", "_", "‚Äú"
}) - set(["#", "!", "?"]))
punct_chars.sort()
punctuation = "".join(punct_chars)
# Compiled character class matching any single character listed above.
replace = re.compile("[%s]" % re.escape(punctuation))

# Load the GloVe and emoji2vec embeddings
def load_embeddings(glove_path, emoji2vec_path, vocab_size, embedding_dim, logger):
    """Load GloVe word vectors and emoji2vec vectors from disk.

    Args:
        glove_path: Path to a GloVe text file, one ``token v1 v2 ...`` entry
            per line.
        emoji2vec_path: Path to an emoji2vec file in word2vec text format.
        vocab_size: Not used by this function; kept for interface
            compatibility with existing callers.
        embedding_dim: Expected number of components per GloVe vector. Lines
            with a different number of components (including blank lines) are
            skipped instead of aborting the load.
        logger: Logger that receives progress, warning and error messages.

    Returns:
        A ``(glove_embeddings, emoji2vec)`` tuple: a dict mapping token to a
        float32 numpy vector, and the gensim ``KeyedVectors`` for emoji.

    Raises:
        ValueError: If no usable GloVe vector is found or the emoji2vec file
            starts with an empty line.
        Exception: Any other error raised while reading either file is logged
            and re-raised unchanged.
    """
    logger.info(f"ƒêang t·∫£i GloVe t·ª´ {glove_path}")
    glove_embeddings = {}
    skipped_lines = 0
    try:
        with open(glove_path, 'r', encoding='utf-8') as f:
            for line in f:
                values = line.split()
                # A well-formed entry is exactly one token plus embedding_dim numbers.
                if len(values) != embedding_dim + 1:
                    skipped_lines += 1
                    continue
                try:
                    vector = np.asarray(values[1:], dtype='float32')
                except ValueError:
                    # Non-numeric component: treat the line as malformed.
                    skipped_lines += 1
                    continue
                glove_embeddings[values[0]] = vector
        if skipped_lines:
            logger.warning(
                f"Skipped {skipped_lines} malformed GloVe lines "
                f"(expected {embedding_dim} components per vector)"
            )
        if not glove_embeddings:
            # Mirrors the emptiness check applied to emoji2vec below.
            raise ValueError(
                f"No {embedding_dim}-dimensional vectors found in {glove_path}"
            )
        logger.info(f"ƒê√£ t·∫£i {len(glove_embeddings)} vector GloVe")
    except Exception as e:
        logger.error(f"L·ªói khi t·∫£i GloVe t·ª´ {glove_path}: {str(e)}")
        raise

    logger.info(f"ƒêang t·∫£i emoji2vec t·ª´ {emoji2vec_path}")
    emoji2vec = None
    try:
        # Peek at the first line so an empty file gives a clear error message
        # instead of whatever the word2vec parser would raise.
        with open(emoji2vec_path, 'r', encoding='utf-8') as f:
            first_line = f.readline()
            if not first_line.strip():
                raise ValueError(f"T·ªáp emoji2vec {emoji2vec_path} r·ªóng")
        emoji2vec = gensim.models.KeyedVectors.load_word2vec_format(
            emoji2vec_path, binary=False, unicode_errors='ignore'
        )
        logger.info(f"ƒê√£ t·∫£i {len(emoji2vec.key_to_index)} vector emoji")
    except Exception as e:
        logger.error(f"L·ªói khi t·∫£i emoji2vec t·ª´ {emoji2vec_path}: {str(e)}")
        raise
    return glove_embeddings, emoji2vec

# Build the embedding matrix
def create_embedding_matrix(tokenizer, glove_embeddings, emoji2vec, vocab_size, embedding_dim, logger):
    """Assemble the initial embedding weights for the tokenizer vocabulary.

    Row ``idx`` of the result holds the vector of the token whose index in
    ``tokenizer.word_index`` is ``idx``. Emoji tokens present in ``emoji2vec``
    take their vector from there; other tokens use ``glove_embeddings`` when
    available. Tokens found in neither source, and indices at or beyond
    ``vocab_size``, leave their row as zeros.

    Returns:
        A ``(vocab_size, embedding_dim)`` numpy array.
    """
    weights = np.zeros((vocab_size, embedding_dim))
    n_glove = 0
    n_emoji = 0
    n_missing = 0

    for token, row in tokenizer.word_index.items():
        if row >= vocab_size:
            # Index falls outside the matrix; nothing to fill.
            continue

        # emoji2vec takes precedence over GloVe for emoji tokens.
        if emoji.is_emoji(token) and token in emoji2vec:
            weights[row] = emoji2vec[token]
            n_emoji += 1
            continue

        if token in glove_embeddings:
            weights[row] = glove_embeddings[token]
            n_glove += 1
            continue

        n_missing += 1

    logger.info(f"Ma tr·∫≠n embedding: {n_glove} GloVe hits, {n_emoji} emoji2vec hits, {n_missing} misses")
    return weights

# Text preprocessing function
def clean_text(text, logger):
    """Normalise one raw text for tokenisation.

    Steps, in order: repair encoding with ftfy, lowercase, expand
    contractions, map special characters, split hashtags into words with
    wordninja, drop URLs and @mentions, replace punctuation with spaces
    (``#``, ``!`` and ``?`` are kept), normalise known misspellings and
    collapse whitespace.

    Args:
        text: Raw input. Anything that is not a non-empty ``str`` yields ``""``.
        logger: Logger for debug output, or ``None`` to disable it.

    Returns:
        The cleaned, lowercased text (possibly empty).
    """
    if not isinstance(text, str) or not text:
        return ""
    if logger is not None:
        logger.debug(f"VƒÉn b·∫£n g·ªëc: {text}")

    # Repair character-encoding damage.
    text = ftfy.fix_text(text)

    # Lowercase before expanding contractions so that capitalised forms such as
    # "Don't" are expanded too; the final output is lowercase either way.
    text = text.lower()

    # Expand contractions (keys/values are lowercased to match the text).
    for contraction, full_form in contraction_mapping.items():
        text = text.replace(contraction.lower(), full_form.lower())

    # Map special characters to their replacements.
    for p, replacement in punct_mapping.items():
        text = text.replace(p, replacement)

    # Split hashtags into their component words and drop the leading '#'.
    def split_hashtag(match):
        hashtag = match.group(0)[1:]
        words = wordninja.split(hashtag)
        return ' '.join(words)
    text = re.sub(r"#\w+", split_hashtag, text)

    # Remove URLs and mentions. The mention pattern also matches at the very
    # start of the text, not only after whitespace.
    text = re.sub(r"http\S*|\S*\.com\S*|\S*www\S*", " ", text)
    text = re.sub(r"(?:^|\s)@\S+", " ", text)

    # Replace punctuation with spaces.
    text = replace.sub(" ", text)

    # Normalise misspellings. Tokens are lowercase and whitespace-free here, so
    # the table is normalised the same way; it is built once and cached on the
    # function object to avoid rebuilding it for every text.
    lookup = getattr(clean_text, "_mispell_lookup", None)
    if lookup is None:
        lookup = {
            wrong.strip().lower(): right.strip().lower()
            for wrong, right in mispell_dict.items()
        }
        clean_text._mispell_lookup = lookup
    words = [lookup.get(word, word) for word in text.split()]
    text = ' '.join(words)
    text = re.sub(r"\s+", " ", text).strip()

    if logger is not None:
        logger.debug(f"VƒÉn b·∫£n ƒë√£ x·ª≠ l√Ω: {text}")
    return text

# GoEmotions data processor class
class GoemotionsProcessor:
    """Reads GoEmotions-style TSV splits and turns them into cleaned examples.

    ``args`` must expose ``data_dir``, ``label_file``, ``train_file``,
    ``dev_file`` and ``test_file``; ``logger`` receives all progress, warning
    and error messages.
    """

    def __init__(self, args, logger):
        self.args = args
        self.logger = logger

    def get_labels(self):
        """Return the list of label names.

        Reads one label per non-blank line from ``data_dir/label_file``. If
        that file does not exist, the built-in list of 28 GoEmotions labels is
        returned instead.

        Raises:
            ValueError: If the label file exists but contains no labels.
        """
        label_file = os.path.join(self.args.data_dir, self.args.label_file)
        self.logger.info(f"ƒêang ƒë·ªçc t·ªáp nh√£n t·∫°i: {label_file}")
        if not os.path.exists(label_file):
            self.logger.warning(f"Kh√¥ng t√¨m th·∫•y t·ªáp nh√£n t·∫°i {label_file}. S·ª≠ d·ª•ng nh√£n m·∫∑c ƒë·ªãnh GoEmotions.")
            labels = [
                'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion',
                'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment',
                'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism',
                'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral'
            ]
            self.logger.info(f"S·ª≠ d·ª•ng {len(labels)} nh√£n m·∫∑c ƒë·ªãnh")
            return labels
        try:
            with open(label_file, "r", encoding="utf-8") as f:
                labels = [line.strip() for line in f if line.strip()]
            if not labels:
                self.logger.error(f"T·ªáp nh√£n {label_file} r·ªóng")
                raise ValueError(f"T·ªáp nh√£n {label_file} r·ªóng")
            self.logger.info(f"ƒê√£ ƒë·ªçc {len(labels)} nh√£n t·ª´ {label_file}")
            return labels
        except Exception as e:
            self.logger.error(f"L·ªói khi ƒë·ªçc t·ªáp nh√£n {label_file}: {str(e)}")
            raise

    def _read_file(self, input_file):
        """Read a headerless three-column TSV (text, labels, id) into a DataFrame.

        All columns are read as strings. Quote characters are treated as
        ordinary text and no value is converted to NaN, so free-form comments
        containing ``"`` or consisting of words like ``NA`` survive intact.

        Raises:
            FileNotFoundError: If ``input_file`` does not exist.
            ValueError: If the file contains no rows.
        """
        if not os.path.exists(input_file):
            self.logger.error(f"Kh√¥ng t√¨m th·∫•y t·ªáp d·ªØ li·ªáu t·∫°i {input_file}")
            raise FileNotFoundError(f"Kh√¥ng t√¨m th·∫•y t·ªáp d·ªØ li·ªáu t·∫°i {input_file}")
        try:
            import csv  # local import: only the quoting constant is needed

            df = pd.read_csv(
                input_file,
                sep='\t',
                header=None,
                names=['text', 'labels', 'id'],
                dtype=str,
                # With default quoting, a field starting with an unbalanced '"'
                # swallows the following rows.
                quoting=csv.QUOTE_NONE,
                # Keep literal texts such as "NA" or "null" instead of NaN.
                keep_default_na=False,
            )
            self.logger.info(f"ƒê√£ ƒë·ªçc {len(df)} d√≤ng t·ª´ {input_file}")
            if df.empty:
                self.logger.error(f"T·ªáp d·ªØ li·ªáu {input_file} r·ªóng")
                raise ValueError(f"T·ªáp d·ªØ li·ªáu {input_file} r·ªóng")
            return df
        except Exception as e:
            self.logger.error(f"L·ªói khi ƒë·ªçc t·ªáp {input_file}: {str(e)}")
            raise

    def _augment_data(self, texts, labels):
        """Return the inputs interleaved with one synonym-augmented copy each.

        For every ``(text, label)`` pair the original is emitted, followed by
        an augmented variant carrying the same label. When the augmenter
        produces nothing usable (for example for an empty text), the original
        text is repeated so texts and labels stay aligned.
        """
        aug = naw.SynonymAug(aug_p=0.3)
        augmented_texts, augmented_labels = [], []
        for text, label in zip(texts, labels):
            augmented_texts.append(text)
            augmented_labels.append(label)

            result = aug.augment(text) if text else None
            # The augmenter may hand back a list of strings or a bare string
            # (assumed to depend on the nlpaug version — TODO confirm);
            # indexing a bare string with [0] would keep only its first char.
            if isinstance(result, (list, tuple)):
                aug_text = result[0] if result else text
            elif isinstance(result, str) and result:
                aug_text = result
            else:
                aug_text = text
            augmented_texts.append(aug_text)
            augmented_labels.append(label)
        self.logger.info(f"S·ªë m·∫´u sau tƒÉng c∆∞·ªùng: {len(augmented_texts)}")
        return augmented_texts, augmented_labels

    def _balance_labels(self, examples, label_list_len, set_type):
        """Rebalance the label distribution of the training split.

        Non-training splits are returned untouched. For training data a target
        count per label is derived (six times the median label count, capped at
        half the number of examples). Labels below the target keep all their
        examples and are topped up with random duplicates; labels at or above
        it are randomly down-sampled to ``max(target, 50% of their count)``.
        An example with several labels may be selected once per label.

        Returns:
            A shuffled list of examples.
        """
        if set_type != "train":
            return examples
        self.logger.info("C√¢n b·∫±ng nh√£n cho t·∫≠p hu·∫•n luy·ªán")
        label_counts = Counter()
        for ex in examples:
            label_counts.update(ex['labels'])
        self.logger.info(f"Ph√¢n b·ªë nh√£n ban ƒë·∫ßu: {dict(label_counts)}")
        counts = [count for count in label_counts.values() if count > 0]
        if not counts:
            # Nothing to balance; avoids taking the median of an empty list.
            return examples
        target_count = min(int(np.median(counts) * 6.0), len(examples) // 2)
        self.logger.info(f"S·ªë l∆∞·ª£ng m·ª•c ti√™u m·ªói nh√£n: {target_count}")
        balanced_examples = []
        for label in range(label_list_len):
            current_count = label_counts[label]
            if current_count == 0:
                continue
            samples_with_label = [ex for ex in examples if label in ex['labels']]
            if current_count < target_count:
                # Keep every original example, then add only the shortfall as
                # random duplicates. Adding just the duplicates would discard
                # real data and leave the label short of the target.
                samples_needed = target_count - current_count
                balanced_examples.extend(samples_with_label)
                balanced_examples.extend(random.choices(samples_with_label, k=samples_needed))
            else:
                samples_to_keep = max(target_count, int(current_count * 0.5))
                balanced_examples.extend(random.sample(samples_with_label, min(samples_to_keep, len(samples_with_label))))
        # Carry over examples that have none of the known labels.
        balanced_examples.extend([ex for ex in examples if not any(label in ex['labels'] for label in range(label_list_len))])
        random.shuffle(balanced_examples)
        new_label_counts = Counter()
        for ex in balanced_examples:
            new_label_counts.update(ex['labels'])
        self.logger.info(f"Ph√¢n b·ªë nh√£n sau c√¢n b·∫±ng: {dict(new_label_counts)}")
        self.logger.info(f"S·ªë m·∫´u ban ƒë·∫ßu: {len(examples)}, S·ªë m·∫´u sau c√¢n b·∫±ng: {len(balanced_examples)}")
        return balanced_examples

    def get_examples(self, mode):
        """Load and preprocess the split named by ``mode`` ('train', 'dev' or 'test').

        Raises:
            ValueError: If ``mode`` is not one of the supported split names.
        """
        file_map = {
            'train': self.args.train_file,
            'dev': self.args.dev_file,
            'test': self.args.test_file
        }
        file_to_read = file_map.get(mode)
        if not file_to_read:
            raise ValueError("Mode ph·∫£i l√† 'train', 'dev', ho·∫∑c 'test'")
        file_path = os.path.join(self.args.data_dir, file_to_read)
        self.logger.info(f"ƒêang ƒë·ªçc d·ªØ li·ªáu {mode} t·ª´ {file_path}")
        df = self._read_file(file_path)
        return self._create_examples(df, mode)

    def _create_examples(self, df, set_type):
        """Convert DataFrame rows into example dicts with cleaned text.

        Each example is ``{'guid', 'text', 'labels'}`` where ``labels`` is the
        list of integer label ids parsed from the comma-separated label field.
        Rows without any valid label id are skipped with a warning. The
        training split is additionally passed through ``_balance_labels``.

        Raises:
            ValueError: If no example could be created.
        """
        examples = []
        label_list_len = len(self.get_labels())
        label_counts = Counter()
        for i, row in df.iterrows():
            guid = f"{set_type}-{i}"
            raw_text = row['text']
            label_str = str(row['labels'])
            try:
                # Keep only numeric ids that fall inside the label list.
                label = [int(l) for l in label_str.split(',') if l.strip().isdigit()]
                label = [l for l in label if 0 <= l < label_list_len]
                if not label:
                    self.logger.warning(f"Kh√¥ng c√≥ nh√£n h·ª£p l·ªá t·∫°i d√≤ng {i}: {label_str}. B·ªè qua.")
                    continue
                label_counts.update(label)
            except (ValueError, IndexError) as e:
                self.logger.warning(f"Nh√£n kh√¥ng h·ª£p l·ªá t·∫°i d√≤ng {i}: {label_str}. B·ªè qua. L·ªói: {e}")
                continue
            cleaned_text = clean_text(raw_text, self.logger)
            examples.append({
                'guid': guid,
                'text': cleaned_text,
                'labels': label
            })
        self.logger.info(f"ƒê√£ t·∫°o {len(examples)} m·∫´u t·ª´ {set_type}")
        self.logger.info(f"Ph√¢n b·ªë nh√£n cho {set_type}: {dict(label_counts)}")
        if not examples:
            self.logger.error(f"Kh√¥ng t·∫°o ƒë∆∞·ª£c m·∫´u t·ª´ {set_type}. Ki·ªÉm tra t·ªáp d·ªØ li·ªáu!")
            raise ValueError(f"Kh√¥ng t·∫°o ƒë∆∞·ª£c m·∫´u t·ª´ {set_type}")
        examples = self._balance_labels(examples, label_list_len, set_type)
        return examples

# Data loading function
def load_data(args, logger):
    """Load the train/dev/test splits, using pickle caches in ``args.data_dir``.

    For each split, ``cached_<mode>_data.pkl`` is used when present and
    readable; otherwise the split is processed with ``GoemotionsProcessor``
    (the training split is additionally augmented) and the result is cached.

    Args:
        args: Namespace with the attributes ``GoemotionsProcessor`` needs,
            including ``data_dir``.
        logger: Logger for progress and error messages.

    Returns:
        ``(train_texts, val_texts, test_texts, train_labels, val_labels,
        test_labels, label_list)``.

    Raises:
        Exception: Any error while loading a split is logged and re-raised.
    """
    processor = GoemotionsProcessor(args, logger)
    label_list = processor.get_labels()

    def load_and_cache(mode):
        cached_file = os.path.join(args.data_dir, f"cached_{mode}_data.pkl")
        if os.path.exists(cached_file):
            logger.info(f"ƒêang t·∫£i d·ªØ li·ªáu ƒë√£ cache t·ª´ {cached_file}")
            try:
                # NOTE: unpickling can execute arbitrary code; data_dir must
                # only ever point at a trusted location.
                with open(cached_file, 'rb') as f:
                    data = pickle.load(f)
                return data['texts'], data['labels']
            except (pickle.UnpicklingError, EOFError, KeyError, TypeError) as e:
                # A truncated or incompatible cache should not be fatal:
                # fall through and rebuild it from the source file.
                logger.warning(f"Ignoring unreadable cache {cached_file}: {e}")
        logger.info(f"ƒêang x·ª≠ l√Ω d·ªØ li·ªáu {mode}")
        examples = processor.get_examples(mode)
        texts = [ex['text'] for ex in examples]
        labels = [ex['labels'] for ex in examples]
        if mode == 'train':
            texts, labels = processor._augment_data(texts, labels)
        # Write to a temporary file and rename, so an interrupted run never
        # leaves a half-written cache behind.
        tmp_file = cached_file + ".tmp"
        with open(tmp_file, 'wb') as f:
            pickle.dump({'texts': texts, 'labels': labels}, f)
        os.replace(tmp_file, cached_file)
        logger.info(f"ƒê√£ cache d·ªØ li·ªáu {mode} v√†o {cached_file}")
        return texts, labels

    try:
        train_texts, train_labels = load_and_cache('train')
        val_texts, val_labels = load_and_cache('dev')
        test_texts, test_labels = load_and_cache('test')
        logger.info(f"S·ªë m·∫´u hu·∫•n luy·ªán: {len(train_texts)}, S·ªë m·∫´u x√°c th·ª±c: {len(val_texts)}, S·ªë m·∫´u ki·ªÉm tra: {len(test_texts)}")
    except Exception as e:
        logger.error(f"L·ªói khi t·∫£i d·ªØ li·ªáu: {str(e)}")
        raise
    return train_texts, val_texts, test_texts, train_labels, val_labels, test_labels, label_list

# Encode labels as multi-hot vectors
def to_multi_hot(label_lists, num_labels):
    """Encode per-sample label ids as an int32 multi-hot matrix.

    Row ``r`` of the ``(len(label_lists), num_labels)`` result has a 1 in
    every column listed in ``label_lists[r]`` and 0 elsewhere.
    """
    encoded = np.zeros((len(label_lists), num_labels), dtype=np.int32)
    row = 0
    for active in label_lists:
        # Fancy indexing sets all active columns of this row at once.
        encoded[row, active] = 1
        row += 1
    return encoded

# Compute class weights using log scaling
def compute_class_weights(labels, num_labels, logger):
    """Compute log-scaled, normalised per-label weights.

    For every label that occurs at least once the raw weight is
    ``log(total_samples / count)``, doubled when the label is rarer than the
    median label count, and floored at 1.0. Labels that never occur get the
    neutral raw weight 1.0. Finally all weights are rescaled so that they sum
    to ``num_labels``.

    Args:
        labels: Iterable of per-sample label-id lists.
        num_labels: Total number of labels.
        logger: Logger that receives the counts and resulting weights.

    Returns:
        Dict mapping label id to weight.
    """
    label_counts = np.zeros(num_labels)
    for labs in labels:
        for l in labs:
            label_counts[l] += 1

    total_samples = len(labels)
    median_count = np.median(label_counts)

    class_weights = {}
    for i in range(num_labels):
        count = label_counts[i]
        if count > 0:
            weight = np.log(total_samples / count)
            if count < median_count:
                weight *= 2.0
            class_weights[i] = max(weight, 1.0)
        else:
            # Absent labels stay neutral. Clamping their count to a tiny
            # epsilon instead would give them a huge log weight that swamps
            # the normalisation below.
            class_weights[i] = 1.0

    weight_sum = sum(class_weights.values())
    if weight_sum > 0:
        scale_factor = num_labels / weight_sum
        for i in range(num_labels):
            class_weights[i] *= scale_factor

    logger.info(f"Ph√¢n b·ªë nh√£n: {label_counts}")
    logger.info(f"Tr·ªçng s·ªë l·ªõp (s·ª≠ d·ª•ng log): {class_weights}")
    return class_weights

# Threshold optimization function
def optimize_threshold(y_true, y_pred_probs, logger):
    """Pick the decision threshold that maximises macro F1.

    Candidates run from 0.1 to 0.9 in steps of 0.05. The first candidate with
    the strictly highest macro F1 wins; if no candidate scores above 0.0 the
    default threshold 0.5 is returned.
    """
    candidates = np.arange(0.1, 0.91, 0.05)
    scores = [
        f1_score(
            y_true,
            (y_pred_probs >= cutoff).astype(int),
            average='macro',
            zero_division=0,
        )
        for cutoff in candidates
    ]

    best_threshold, best_macro_f1 = 0.5, 0.0
    for cutoff, score in zip(candidates, scores):
        if score > best_macro_f1:
            best_threshold, best_macro_f1 = cutoff, score

    logger.info(f"Ng∆∞·ª°ng t·ªët nh·∫•t: {best_threshold}, Macro F1: {best_macro_f1:.4f}")
    return best_threshold

# Plot all per-label confusion matrices in one figure (subplots)
def plot_confusion_matrices(y_true, y_pred, label_list, output_dir, model_type, logger):
    """Save one figure holding a 2x2 confusion matrix for every label.

    Args:
        y_true: Multi-hot ground-truth array; column ``i`` belongs to
            ``label_list[i]``.
        y_pred: Multi-hot prediction array with the same layout.
        label_list: Label names, used as subplot titles.
        output_dir: Directory for the PNG (created if missing).
        model_type: Model name, used in the log message only.
        logger: Logger that receives the output path.
    """
    os.makedirs(output_dir, exist_ok=True)
    timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
    num_labels = len(label_list)
    # Six columns and as many rows as needed (5x6 for the 28 GoEmotions labels).
    cols = 6
    rows = max(1, -(-num_labels // cols))
    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * 3, rows * 2.5), constrained_layout=True, squeeze=False
    )
    axes = axes.flatten()

    for i, label in enumerate(label_list):
        # labels=[0, 1] forces a 2x2 matrix even when a label is entirely
        # absent from both y_true and y_pred (otherwise the result is 1x1 and
        # no longer matches the two tick labels).
        cm = confusion_matrix(y_true[:, i], y_pred[:, i], labels=[0, 1])
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Negative', 'Positive'],
                    yticklabels=['Negative', 'Positive'], ax=axes[i], cbar=False)
        axes[i].set_title(f'{label}', fontsize=10)
        axes[i].set_xlabel('Predicted', fontsize=8)
        axes[i].set_ylabel('True', fontsize=8)
        axes[i].tick_params(labelsize=8)

    # Hide the unused cells of the grid.
    for unused_ax in axes[num_labels:]:
        unused_ax.axis('off')

    cm_file = os.path.join(output_dir, f'all_confusion_matrices_{timestamp}.png')
    fig.savefig(cm_file, bbox_inches='tight', dpi=300)
    plt.close(fig)
    logger.info(f"ƒê√£ l∆∞u t·∫•t c·∫£ ma tr·∫≠n nh·∫ßm l·∫´n cho {model_type} v√†o {cm_file}")

# Plot the aggregated confusion matrix
def plot_aggregated_confusion_matrix(y_true, y_pred, output_dir, model_type, logger):
    """Save a single 2x2 confusion matrix summed over all labels.

    Args:
        y_true: Multi-hot ground-truth array, one column per label.
        y_pred: Multi-hot prediction array with the same layout.
        output_dir: Directory for the PNG (created if missing).
        model_type: Model name, used in the title and log message.
        logger: Logger that receives the output path.
    """
    os.makedirs(output_dir, exist_ok=True)
    timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
    num_labels = y_true.shape[1]
    aggregated_cm = np.zeros((2, 2), dtype=int)

    for i in range(num_labels):
        # labels=[0, 1] guarantees a 2x2 result. Without it a column that
        # contains a single class yields a 1x1 matrix, which numpy would
        # silently broadcast into all four cells of the aggregate.
        aggregated_cm += confusion_matrix(y_true[:, i], y_pred[:, i], labels=[0, 1])

    plt.figure(figsize=(6, 4))
    sns.heatmap(aggregated_cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Negative', 'Positive'],
                yticklabels=['Negative', 'Positive'])
    plt.title(f'Aggregated Confusion Matrix Across All Labels ({model_type})')
    plt.ylabel('True')
    plt.xlabel('Predicted')
    cm_file = os.path.join(output_dir, f'aggregated_confusion_matrix_{timestamp}.png')
    plt.savefig(cm_file, bbox_inches='tight', dpi=300)
    plt.close()
    logger.info(f"ƒê√£ l∆∞u ma tr·∫≠n nh·∫ßm l·∫´n t·ªïng h·ª£p cho {model_type} v√†o {cm_file}")

# Plot a heatmap of the performance metrics
def plot_performance_metrics_heatmap(y_true, y_pred, label_list, output_dir, model_type, logger):
    """Save a heatmap of per-label precision, recall and F1.

    Rows of the heatmap are the entries of ``label_list``; the three columns
    are precision, recall and F1-score. The PNG is written to ``output_dir``
    (created if missing) with a timestamped file name.
    """
    os.makedirs(output_dir, exist_ok=True)
    timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")

    per_label = precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)
    precision, recall, f1 = per_label[0], per_label[1], per_label[2]
    # Transpose so that each row is a label and each column a metric.
    stacked = np.array([precision, recall, f1])
    metrics = stacked.T

    column_names = ['Precision', 'Recall', 'F1-Score']
    plt.figure(figsize=(10, 12))
    sns.heatmap(
        metrics,
        annot=True,
        fmt='.3f',
        cmap='YlGnBu',
        xticklabels=column_names,
        yticklabels=label_list,
    )
    plt.title(f'Performance Metrics Heatmap ({model_type})')

    heatmap_file = os.path.join(output_dir, f'performance_metrics_heatmap_{timestamp}.png')
    plt.savefig(heatmap_file, bbox_inches='tight', dpi=300)
    plt.close()
    logger.info(f"ƒê√£ l∆∞u heatmap ch·ªâ s·ªë hi·ªáu su·∫•t cho {model_type} v√†o {heatmap_file}")

# Plot the loss curves
def plot_loss(history, output_dir, model_type, logger):
    """Save a line plot of training and validation loss per epoch.

    Reads ``history.history['loss']`` and ``history.history['val_loss']`` and
    writes a timestamped PNG into ``output_dir`` (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
    plt.figure(figsize=(10, 6))

    train_loss = history.history['loss']
    # Epochs are numbered from 1 on the x axis.
    epochs = range(1, len(train_loss) + 1)
    curves = (
        (train_loss, 'Training Loss', 'blue'),
        (history.history['val_loss'], 'Validation Loss', 'orange'),
    )
    for values, name, colour in curves:
        plt.plot(epochs, values, label=name, color=colour, linewidth=2)

    plt.title(f'Training and Validation Loss for {model_type}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    loss_plot_file = os.path.join(output_dir, f'loss_plot_{timestamp}.png')
    plt.savefig(loss_plot_file, bbox_inches='tight', dpi=300)
    plt.close()
    logger.info(f"ƒê√£ l∆∞u bi·ªÉu ƒë·ªì loss cho {model_type} v√†o {loss_plot_file}")

# Prediction function for new examples
def predict_pipeline(model, tokenizer, texts, label_list, max_len, threshold=0.25, logger=None):
    """Clean, tokenise and classify raw texts, printing and returning the results.

    Args:
        model: Object whose ``predict(padded, verbose=0)`` returns one
            probability per label for each input row.
        tokenizer: Object providing ``texts_to_sequences``.
        texts: Iterable of raw input strings.
        label_list: Label names, in the order of the model's outputs.
        max_len: Sequence length used for post-padding and post-truncation.
        threshold: Minimum probability for a label to be reported.
        logger: Logger passed to ``clean_text``; defaults to this module's
            logger when omitted.

    Returns:
        One dict per input text with the keys ``text``, ``labels`` (list of
        label names, or a placeholder string when none reach the threshold),
        ``top_labels`` (the three most probable ``(label, prob)`` pairs) and
        ``probs`` (all probabilities as a list). Empty input yields ``[]``.
    """
    if logger is None:
        # clean_text logs through the logger it is given, so a usable one is
        # required even when the caller did not supply any.
        logger = logging.getLogger(__name__)

    texts = list(texts)
    if not texts:
        # Nothing to predict; skip the tokenizer and the model entirely.
        return []

    cleaned_texts = [clean_text(text, logger) for text in texts]
    sequences = tokenizer.texts_to_sequences(cleaned_texts)
    padded = pad_sequences(sequences, maxlen=max_len, padding="post", truncating="post")
    pred_probs = model.predict(padded, verbose=0)
    predictions = (pred_probs >= threshold).astype(int)

    results = []
    for i, text in enumerate(texts):
        labels = [name for name, hit in zip(label_list, predictions[i]) if hit == 1]
        # Indices of the three highest probabilities, most probable first.
        top_indices = np.argsort(pred_probs[i])[-3:][::-1]
        top_labels = [label_list[j] for j in top_indices]
        top_probs = [pred_probs[i][j] for j in top_indices]
        results.append({
            "text": text,
            "labels": labels if labels else "Kh√¥ng c√≥ nh√£n",
            "top_labels": list(zip(top_labels, top_probs)),
            "probs": pred_probs[i].tolist()
        })
        print(f"\nV√≠ d·ª• {i+1}: {text}")
        print(f"D·ª± ƒëo√°n (ng∆∞·ª°ng {threshold}): {labels if labels else 'Kh√¥ng c√≥ nh√£n'}")
        print(f"Top-3 nh√£n: {list(zip(top_labels, top_probs))}")
    return results

# Main function
def main():
    """Train, evaluate and export the BiLSTM multi-label classifier end to end.

    Steps, in order: parse command-line arguments, mount Google Drive (best
    effort), create the checkpoint and visualization directories, load the
    data splits, fit and pickle the tokenizer, build the embedding matrix,
    build and compile the model under the selected distribution strategy,
    train it, tune the decision threshold on the validation split, report
    test metrics, save the plots, run a few sample predictions, and finally
    write the results file, the model, the tokenizer, the label list and a
    JSON config to disk.
    """
    import json

    # Configure parameters
    parser = argparse.ArgumentParser(description="BiLSTM Multi-Label Classification for GoEmotions")
    parser.add_argument("--data_dir", default="/content/drive/MyDrive/Goemotions/data", type=str, help="Th∆∞ m·ª•c ch·ª©a d·ªØ li·ªáu")
    parser.add_argument("--train_file", default="train.tsv", type=str, help="T·ªáp d·ªØ li·ªáu hu·∫•n luy·ªán")
    parser.add_argument("--dev_file", default="dev.tsv", type=str, help="T·ªáp d·ªØ li·ªáu x√°c th·ª±c")
    parser.add_argument("--test_file", default="test.tsv", type=str, help="T·ªáp d·ªØ li·ªáu ki·ªÉm tra")
    parser.add_argument("--label_file", default="labels.txt", type=str, help="T·ªáp ch·ª©a danh s√°ch nh√£n")
    parser.add_argument("--glove_path", default="/content/drive/MyDrive/Goemotions/glove.6B.300d.txt", type=str, help="ƒê∆∞·ªùng d·∫´n ƒë·∫øn t·ªáp GloVe")
    parser.add_argument("--emoji2vec_path", default="/content/drive/MyDrive/Goemotions/emoji2vec.txt", type=str, help="ƒê∆∞·ªùng d·∫´n ƒë·∫øn t·ªáp emoji2vec")
    parser.add_argument("--ckpt_dir", default="/content/drive/MyDrive/Goemotions/checkpoints", type=str, help="Th∆∞ m·ª•c l∆∞u checkpoint m√¥ h√¨nh")
    parser.add_argument("--model_type", default="goemotions-bilstm", type=str, help="Lo·∫°i m√¥ h√¨nh (goemotions-bilstm ho·∫∑c vaafi-bilstm)")

    # Drop arguments starting with "-f" or ending in ".json" before parsing,
    # so that extra arguments of that shape in sys.argv (presumably the
    # notebook kernel's connection-file arguments) do not break argparse.
    args = parser.parse_args([arg for arg in sys.argv[1:] if not arg.startswith('-f') and not arg.endswith('.json')])

    # Mount Google Drive (best effort: a failure is reported but not fatal)
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
        print("Google Drive mounted successfully")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")

    # Create the checkpoint directory
    os.makedirs(args.ckpt_dir, exist_ok=True)
    logger.info(f"Th∆∞ m·ª•c checkpoint: {args.ckpt_dir}")

    # Create the visualization directory, suffixed with a timestamp run_id
    run_id = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join("/content/drive/MyDrive/Goemotions/visualizations", f"{args.model_type}_{run_id}")
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Th∆∞ m·ª•c visualization: {output_dir}")

    # Enable mixed precision (must happen before any layer is created)
    # NOTE(review): the policy is fixed to float16 before the TPU/GPU choice
    # below is known — confirm this is intended for TPU runs.
    set_global_policy('mixed_float16')

    # Load data
    train_texts, val_texts, test_texts, train_labels, val_labels, test_labels, label_list = load_data(args, logger)
    num_labels = len(label_list)

    # Encode labels as multi-hot vectors
    y_train = to_multi_hot(train_labels, num_labels)
    y_val = to_multi_hot(val_labels, num_labels)
    y_test = to_multi_hot(test_labels, num_labels)

    # Tokenization and padding
    vocab_size = 30000
    max_len = 100
    embedding_dim = 300

    # The filter set omits '#' and the apostrophe, so those characters are
    # kept inside tokens.
    tokenizer = Tokenizer(num_words=vocab_size, oov_token="<OOV>", filters='!"$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n')
    tokenizer.fit_on_texts(train_texts)

    # Persist the tokenizer early so it is available even if training fails.
    with open(os.path.join(args.ckpt_dir, f"{args.model_type}_tokenizer.pkl"), "wb") as f:
        pickle.dump(tokenizer, f)
    logger.info(f"ƒê√£ l∆∞u tokenizer v√†o {args.model_type}_tokenizer.pkl")

    def encode(texts):
        """Convert texts to post-padded / post-truncated integer sequences."""
        seq = tokenizer.texts_to_sequences(texts)
        return pad_sequences(seq, maxlen=max_len, padding="post", truncating="post")

    X_train = encode(train_texts)
    X_val = encode(val_texts)
    X_test = encode(test_texts)

    # Load pretrained vectors and build the embedding matrix
    glove_embeddings, emoji2vec = load_embeddings(args.glove_path, args.emoji2vec_path, vocab_size, embedding_dim, logger)
    embedding_matrix = create_embedding_matrix(tokenizer, glove_embeddings, emoji2vec, vocab_size, embedding_dim, logger)

    # Set up TPU, falling back to the default strategy (GPU/CPU)
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.TPUStrategy(tpu)
        logger.info("Ch·∫°y tr√™n TPU")
    except ValueError:
        strategy = tf.distribute.get_strategy()
        logger.info("Ch·∫°y tr√™n GPU/CPU")

    with strategy.scope():
        # Build the BiLSTM model
        inputs = Input(shape=(max_len,), dtype="int32")
        embed = Embedding(vocab_size, embedding_dim, weights=[embedding_matrix], trainable=True)(inputs)

        # First BiLSTM layer
        bilstm_out = Bidirectional(LSTM(128, return_sequences=True))(embed)
        bilstm_out = Dropout(0.5)(bilstm_out)
        # Second BiLSTM layer
        bilstm_out = Bidirectional(LSTM(64))(bilstm_out)
        drop = Dropout(0.5)(bilstm_out)
        # Keep the output layer in float32: under the mixed_float16 policy set
        # above, the sigmoid would otherwise be computed in float16, which the
        # TensorFlow mixed-precision guide advises against for numeric
        # stability of the final activations and the loss.
        output = Dense(num_labels, activation="sigmoid", dtype="float32")(drop)

        model = Model(inputs, output)

        # Compile the model
        model.compile(
            optimizer=Adam(learning_rate=5e-4),
            loss=tf.keras.losses.BinaryCrossentropy(),
            metrics=["accuracy"]
        )

    # Set up callbacks (no EarlyStopping)
    # Doubled braces keep {epoch}/{val_accuracy} as placeholders for
    # ModelCheckpoint to fill in at save time.
    checkpoint_path = os.path.join(args.ckpt_dir, f"{args.model_type}_model_{{epoch:02d}}_{{val_accuracy:.4f}}.keras")
    checkpoint = ModelCheckpoint(
        checkpoint_path,
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    )
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=2,
        min_lr=1e-6
    )

    # Training
    # NOTE(review): class_weight is combined with multi-hot targets here —
    # confirm Keras applies it as intended for multi-label outputs.
    batch_size = 128
    epochs = 30
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        batch_size=batch_size,
        epochs=epochs,
        class_weight=compute_class_weights(train_labels, num_labels, logger),
        callbacks=[checkpoint, reduce_lr]
    )

    # Evaluation
    loss, acc = model.evaluate(X_test, y_test, batch_size=batch_size)
    logger.info(f"Test Loss: {loss:.4f}, Test Accuracy: {acc:.4f}")

    # Optimize the decision threshold on the validation set, then apply it
    # to the test-set scores
    val_pred_probs = model.predict(X_val, batch_size=batch_size)
    best_threshold = optimize_threshold(y_val, val_pred_probs, logger)
    pred_probs = model.predict(X_test, batch_size=batch_size)
    y_pred = (pred_probs >= best_threshold).astype(int)

    # Report results
    raw_acc = accuracy_score(y_test.flatten(), y_pred.flatten())
    micro_f1 = f1_score(y_test, y_pred, average='micro', zero_division=0)
    macro_f1 = f1_score(y_test, y_pred, average='macro', zero_division=0)
    weighted_f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)
    report = classification_report(y_test, y_pred, target_names=label_list, zero_division=0, digits=4)

    logger.info(f"\nTest Loss: {loss:.4f}")
    logger.info(f"Keras Accuracy (element-wise): {acc:.4f}")
    logger.info(f"Raw Accuracy (sklearn, element-wise): {raw_acc:.4f}")
    logger.info(f"Micro-averaged F1 score: {micro_f1:.4f}")
    logger.info(f"Macro-averaged F1 score: {macro_f1:.4f}")
    logger.info(f"Weighted-averaged F1 score: {weighted_f1:.4f}")
    logger.info(f"\nClassification Report (threshold {best_threshold}):\n" + report)

    # Create and save the plots
    plot_confusion_matrices(y_test, y_pred, label_list, output_dir, args.model_type, logger)
    plot_aggregated_confusion_matrix(y_test, y_pred, output_dir, args.model_type, logger)
    plot_performance_metrics_heatmap(y_test, y_pred, label_list, output_dir, args.model_type, logger)
    plot_loss(history, output_dir, args.model_type, logger)

    # Error analysis: collect every test sample whose predicted label vector
    # differs from the true one
    errors = [
        (text, true, pred)
        for true, pred, text in zip(y_test, y_pred, test_texts)
        if not np.array_equal(true, pred)
    ]
    logger.info(f"C√°c m·∫´u d·ª± ƒëo√°n sai (top 10): {errors[:10]}")

    # Predict on new examples
    test_examples = [
        "Feeling on top of the world today! üéâüòä #BestDayEver #SoHappy",
        "Totally let down... üò¢üíî #Disappointed #WhyThis",
        "Omg that‚Äôs incredible news! üòç‚ú® #Amazing #Grateful",
        "Head full of thoughts rn ü§Øü§î #Confused #Overthinking",
        "Still can‚Äôt believe this happened. So pissed! üò°üî• #Angry #Unbelievable"
    ]
    results = predict_pipeline(model, tokenizer, test_examples, label_list, max_len, best_threshold, logger)

    # Save results
    output_file = os.path.join(output_dir, f"BiLSTM-results_{run_id}.txt")
    with open(output_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== BiLSTM Results - {args.model_type} - {pd.Timestamp.now()} ===\n")
        f.write(f"Test Loss: {loss:.4f}\n")
        f.write(f"Keras Accuracy (element-wise): {acc:.4f}\n")
        f.write(f"Raw Accuracy (sklearn, element-wise): {raw_acc:.4f}\n")
        f.write(f"Micro-averaged F1 score: {micro_f1:.4f}\n")
        f.write(f"Macro-averaged F1 score: {macro_f1:.4f}\n")
        f.write(f"Weighted-averaged F1 score: {weighted_f1:.4f}\n")
        f.write(f"\nClassification Report (threshold {best_threshold}):\n" + report + "\n")
        f.write("\nExample Predictions:\n")
        for i, result in enumerate(results):
            f.write(f"Example {i+1}: {result['text']}\n")
            f.write(f"Predicted Labels: {result['labels']}\n")
            f.write(f"Top-3 Labels and Probabilities: {result['top_labels']}\n\n")
        f.write("================================\n")

    # Save the model and tokenizer
    logger.info("ƒêang l∆∞u m√¥ h√¨nh v√† tokenizer")
    local_model_dir = os.path.join(args.ckpt_dir, f"{args.model_type}")
    os.makedirs(local_model_dir, exist_ok=True)
    model.save(os.path.join(local_model_dir, f"{args.model_type}_model.keras"))
    with open(os.path.join(local_model_dir, f"{args.model_type}_tokenizer.pkl"), "wb") as f:
        pickle.dump(tokenizer, f)

    with open(os.path.join(local_model_dir, f"{args.model_type}_labels.txt"), "w", encoding="utf-8") as f:
        for label in label_list:
            f.write(f"{label}\n")

    config = {
        "vocab_size": vocab_size,
        "max_len": max_len,
        "embedding_dim": embedding_dim,
        "num_labels": num_labels,
        "bilstm_units": [128, 64],
        "dropout_rate": 0.5,
        "learning_rate": 5e-4,
        # Cast to a built-in float: json cannot serialize NumPy scalar types
        # such as np.float32, which a threshold search may return.
        "best_threshold": float(best_threshold),
        "model_type": args.model_type
    }
    with open(os.path.join(local_model_dir, f"{args.model_type}_config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)

    logger.info(f"ƒê√£ l∆∞u m√¥ h√¨nh v√† tokenizer t·∫°i: {local_model_dir}")

# Script entry point: run the pipeline only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()