In [None]:
!pip install torch==2.3.0 transformers==4.38.2 numpy==1.26.4 gensim emoji wordninja scikit-learn attrdict
!pip install torchviz



In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch  # this cell runs before the main import cell, so torch must be imported here


class AsymmetricLossOptimized(torch.nn.Module):
    """Asymmetric focal-style loss for multi-label classification.

    Positive and negative targets get separate focusing exponents
    (``gamma_pos`` / ``gamma_neg``), and the negative-class probability is
    shifted up by ``clip`` (capped at 1) before the log is taken.

    Args:
        gamma_neg: Focusing exponent applied to negative targets.
        gamma_pos: Focusing exponent applied to positive targets.
        clip: Margin added to the negative probability; ``None`` or ``0`` disables it.
        eps: Lower bound applied to probabilities before ``log``.
    """

    def __init__(self, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-8):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps

    def forward(self, logits, target):
        """Return the scalar loss: per-sample sum over labels, averaged over the batch.

        Args:
            logits: Raw scores; the loss is summed over dim 1.
            target: Multi-hot targets with the same shape as ``logits``.

        Returns:
            A 0-dim float32 tensor.
        """
        # Work in float32: with half-precision logits (e.g. under CUDA autocast)
        # ``eps`` rounds to 0, so log(0) = -inf and ``0 * -inf`` poisons the loss with NaN.
        logits = logits.float()
        target = target.float()

        prob_pos = torch.sigmoid(logits)
        prob_neg = 1 - prob_pos
        if self.clip is not None and self.clip > 0:
            prob_neg = (prob_neg + self.clip).clamp(max=1)

        loss = (
            target * torch.log(prob_pos.clamp(min=self.eps))
            + (1 - target) * torch.log(prob_neg.clamp(min=self.eps))
        )

        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Probability assigned to the correct side of each label.
            pt = prob_pos * target + prob_neg * (1 - target)
            one_sided_gamma = self.gamma_pos * target + self.gamma_neg * (1 - target)
            loss = loss * (1 - pt) ** one_sided_gamma

        return -loss.sum(dim=1).mean()

In [None]:
import sys
import json
import logging
import os
import glob
import torch
import numpy as np
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, TensorDataset
from transformers import BertConfig, BertTokenizer, BertPreTrainedModel, BertModel
from transformers.optimization import get_linear_schedule_with_warmup
from torch.optim import AdamW
from tqdm import tqdm
from statistics import mean
import emoji
import gensim
import re
import string
import wordninja
from typing import Optional, Tuple, List, Dict, Union
from transformers import Pipeline, PreTrainedTokenizer, ModelCard, PreTrainedModel
from transformers.pipelines import ArgumentHandler
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from collections import defaultdict, Counter
import copy
import torch.amp

# Training and evaluation history
train_history = []
train_history_epochs = []
val_history = []
val_history_epochs = []
test_history = []

# Set up the module logger
logger = logging.getLogger(__name__)

# Contraction mapping (contraction -> expanded form)
contraction_mapping = {
    "ain't": "is not", "aren't": "are not", "can't": "cannot", "'cause": "because",
    "could've": "could have", "couldn't": "could not", "didn't": "did not",
    "doesn't": "does not", "don't": "do not", "hadn't": "had not",
    "hasn't": "has not", "haven't": "have not", "he'd": "he would",
    "he'll": "he will", "he's": "he is", "how'd": "how did",
    "how'd'y": "how do you", "how'll": "how will", "how's": "how is",
    "I'd": "I would", "I'd've": "I would have", "I'll": "I will",
    "I'll've": "I will have", "I'm": "I am", "I've": "I have",
    "i'd": "i would", "i'd've": "i would have", "i'll": "i will",
    "i'll've": "i will have", "i'm": "i am", "i've": "i have",
    "isn't": "is not", "it'd": "it would", "it'd've": "it would have",
    "it'll": "it will", "it'll've": "it will have", "it's": "it is",
    "let's": "let us", "ma'am": "madam", "mayn't": "may not",
    "might've": "might have", "mightn't": "might not",
    "mightn't've": "might not have", "must've": "must have",
    "mustn't": "must not", "mustn't've": "must not have",
    "needn't": "need not", "needn't've": "need not have",
    "o'clock": "of the clock", "oughtn't": "ought not",
    "oughtn't've": "ought not have", "shan't": "shall not",
    "sha'n't": "shall not", "shan't've": "shall not have",
    "she'd": "she would", "she'd've": "she would have",
    "she'll": "she will", "she'll've": "she will have",
    "she's": "she is", "should've": "should have", "shouldn't": "should not",
    "shouldn't've": "should not have", "so've": "so have", "so's": "so as",
    "this's": "this is", "that'd": "that would", "that'd've": "that would have",
    "that's": "that is", "there'd": "there would", "there'd've": "there would have",
    "there's": "there is", "here's": "here is", "they'd": "they would",
    "they'd've": "they would have", "they'll": "they will",
    "they'll've": "they will have", "they're": "they are",
    "they've": "they have", "to've": "to have", "wasn't": "was not",
    "we'd": "we would", "we'd've": "we would have", "we'll": "we will",
    "we'll've": "we will have", "we're": "we are", "we've": "we have",
    "weren't": "were not", "what'll": "what will", "what'll've": "what will have",
    "what're": "what are", "what's": "what is", "what've": "what have",
    "when's": "when is", "when've": "when have", "where'd": "where did",
    "where's": "where is", "where've": "where have", "who'll": "who will",
    "who'll've": "who will have", "who's": "who is", "who've": "who have",
    "why's": "why is", "why've": "why have", "will've": "will have",
    "won't": "will not", "won't've": "will not have", "would've": "would have",
    "wouldn't": "would not", "wouldn't've": "would not have", "y'all": "you all",
    "y'all'd": "you all would", "y'all'd've": "you all would have",
    "y'all're": "you all are", "y'all've": "you all have", "you'd": "you would",
    "you'd've": "you would have", "you'll": "you will", "you'll've": "you will have",
    "you're": "you are", "you've": "you have", "u.s": "america", "e.g": "for example"
}

# Punctuation mapping (special character -> plain replacement)
punct_mapping = {
    "‚Äò": "'", "‚Çπ": "e", "¬¥": "'", "¬∞": "", "‚Ç¨": "e", "‚Ñ¢": "tm", "‚àö": " sqrt ",
    "√ó": "x", "¬≤": "2", "‚Äî": "-", "‚Äì": "-", "‚Äô": "'", "_": "-", "`": "'", '‚Äú': '"',
    '‚Äù': '"', '‚Äú': '"', "¬£": "e", '‚àû': 'infinity', 'Œ∏': 'theta', '√∑': '/', 'Œ±': 'alpha',
    '‚Ä¢': '.', '√†': 'a', '‚àí': '-', 'Œ≤': 'beta', '‚àÖ': '', '¬≥': '3', 'œÄ': 'pi'
}

# Misspelling dictionary (misspelled word -> corrected form)
mispell_dict = {
    'colour': 'color', 'centre': 'center', 'favourite': 'favorite', 'travelling': 'traveling',
    'counselling': 'counseling', 'theatre': 'theater', 'cancelled': "canceled", 'labour': 'labor',
    'organisation': 'organization', 'wwii': 'world war 2', 'citicise': 'criticize', 'youtu ': 'youtube ',
    'Qoura': 'Quora', 'sallary': 'salary', 'Whta': 'What', 'narcisist': 'narcissist', 'howdo': 'how do',
    'whatare': 'what are', 'howcan': 'how can', 'howmuch': 'how much', 'howmany': 'how many',
    'whydo': 'why do', 'doI': 'do I', 'theBest': 'the best', 'howdoes': 'how does',
    'mastrubation': 'masturbation', 'mastrubate': 'masturbate', 'mastrubating': 'masturbating',
    'pennis': 'penis', 'Etherium': 'Ethereum', 'narcissit': 'narcissist', 'bigdata': 'big data',
    '2k17': '2017', '2k18': '2018', 'qouta': 'quota', 'exboyfriend': 'ex boyfriend',
    'airhostess': 'air hostess', 'whst': 'what', 'watsapp': 'whatsapp',
    'demonitisation': 'demonetization', 'demonitization': 'demonetization',
    'demonetisation': 'demonetization'
}

# Punctuation characters to strip ("#", "!" and "?" are kept; "!" and "?" preserve emotional cues)
punct_chars = list((set(string.punctuation) | {
    "‚Äô", "‚Äò", "‚Äì", "‚Äî", "~", "|", "‚Äú", "‚Äù", "‚Ä¶", "'", "`", "_", "‚Äú"
}) - set(["#", "!", "?"]))
punct_chars.sort()
punctuation = "".join(punct_chars)
replace = re.compile("[%s]" % re.escape(punctuation))

# ----------- AttrDict removed; argparse is not used ------------

def init_logger(args):
    """Configure root logging to stdout and ``<args.output_dir>/training.log``.

    Args:
        args: Object exposing ``output_dir``; the directory is created if missing.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.DEBUG,
        handlers=[
            logging.StreamHandler(sys.stdout),
            # Log messages are non-ASCII, so pin the file encoding explicitly.
            logging.FileHandler(os.path.join(args.output_dir, "training.log"), encoding="utf-8"),
        ],
        # Without force=True, basicConfig does nothing when the root logger
        # already has handlers (pre-configured notebook kernels, or this cell
        # being re-run), and the file handler would never be attached.
        force=True,
    )
    logging.getLogger().setLevel(logging.DEBUG)
    sys.stdout.flush()
    logger.info("Logger ƒë√£ ƒë∆∞·ª£c kh·ªüi t·∫°o.")

def set_seed(args):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) from ``args.seed``.

    Args:
        args: Object exposing an integer ``seed`` attribute.
    """
    import random  # local import: the file's import cell does not import `random`

    # Python's own RNG was previously left unseeded, so any code relying on
    # `random` produced different results on each run.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

class BertForMultiLabelClassification(BertPreTrainedModel):
    """BERT model for multi-label classification with optional Emoji2Vec fusion.

    The pooled BERT output is optionally summed with a linear projection of the
    per-sample mean Emoji2Vec vector, then passed through
    dropout -> dense -> ReLU -> classifier to produce one logit per label.
    """

    def __init__(self, config, emoji2vec_path: Optional[str] = None, emoji_dim: int = 300):
        """Build the network.

        Args:
            config: BERT configuration; ``num_labels`` and ``hidden_size`` are read from it.
            emoji2vec_path: Path to a text word2vec-format Emoji2Vec file, or ``None`` to disable.
            emoji_dim: Dimensionality of the Emoji2Vec vectors.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.emoji_dim = emoji_dim
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(0.2)
        self.emoji2vec = self.load_emoji2vec(emoji2vec_path) if emoji2vec_path else None
        # The projection only exists when vectors were actually loaded.
        self.emoji_projection = torch.nn.Linear(emoji_dim, config.hidden_size) if self.emoji2vec else None
        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.relu = torch.nn.ReLU()
        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)
        self.init_weights()
        if self.emoji2vec:
            logger.info(f"ƒê√£ t·∫£i emoji2vec v·ªõi {len(self.emoji2vec)} vector")

    def load_emoji2vec(self, emoji2vec_path: str) -> Optional[gensim.models.KeyedVectors]:
        """Load Emoji2Vec vectors from a text word2vec file.

        Returns:
            The loaded ``KeyedVectors``, or ``None`` when the file is missing,
            empty, or cannot be parsed (the problem is logged, not raised).
        """
        if not os.path.exists(emoji2vec_path):
            logger.warning(f"Kh√¥ng t√¨m th·∫•y file Emoji2Vec t·∫°i {emoji2vec_path}, ti·∫øp t·ª•c kh√¥ng d√πng Emoji2Vec")
            return None
        try:
            # Peek at the header line so an empty file gives a clear error message.
            with open(emoji2vec_path, 'r', encoding='utf-8') as f:
                first_line = f.readline()
                if not first_line.strip():
                    raise ValueError(f"File Emoji2Vec {emoji2vec_path} r·ªóng")
            return gensim.models.KeyedVectors.load_word2vec_format(
                emoji2vec_path, binary=False, unicode_errors='ignore'
            )
        except Exception as e:
            logger.error(f"L·ªói khi t·∫£i Emoji2Vec t·ª´ {emoji2vec_path}: {str(e)}")
            return None

    def get_emoji_embedding(self, emoji_tokens: Optional[List[List[str]]], device: torch.device) -> torch.Tensor:
        """Return one mean Emoji2Vec vector per sample.

        Args:
            emoji_tokens: One list of emoji strings per sample; ``None`` is treated as an empty batch.
            device: Device on which the result tensor is placed.

        Returns:
            Float tensor of shape ``(len(emoji_tokens), emoji_dim)``. Samples
            without emojis get a zero row; emojis missing from the vocabulary
            count as zero vectors in the mean.
        """
        if emoji_tokens is None:
            emoji_tokens = []
        # Accumulate on the host and transfer once, instead of allocating one
        # device tensor per emoji token.
        pooled = np.zeros((len(emoji_tokens), self.emoji_dim), dtype=np.float32)
        if self.emoji2vec:
            for row, tokens in enumerate(emoji_tokens):
                if not tokens:
                    continue
                known = [self.emoji2vec[token] for token in tokens if token in self.emoji2vec]
                if known:
                    # Divide by the total token count so out-of-vocabulary
                    # emojis still weigh in as zero vectors.
                    pooled[row] = np.sum(known, axis=0) / len(tokens)
        return torch.from_numpy(pooled).to(device)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        emoji_tokens: Optional[List[List[str]]] = None,
        class_weights: Optional[torch.Tensor] = None
    ) -> Tuple:
        """Run the model.

        Args:
            emoji_tokens: Optional per-sample emoji lists; must contain one entry per batch row.
            labels: Optional multi-hot targets; when given, the loss is prepended to the output.
            class_weights: Accepted for interface compatibility; not used by the current loss.

        Returns:
            ``(logits, ...)`` or ``(loss, logits, ...)`` when ``labels`` is provided.

        Raises:
            ValueError: If ``emoji_tokens`` does not have one entry per batch row.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds
        )
        pooled_output = outputs[1]
        if self.emoji2vec and emoji_tokens is not None:
            batch_size = pooled_output.size(0)
            if len(emoji_tokens) != batch_size:
                # A mismatched length would otherwise broadcast silently (e.g. an
                # empty list with batch size 1 yields a 0-row output) or fail
                # with an opaque shape error.
                raise ValueError(
                    f"emoji_tokens has {len(emoji_tokens)} entries but the batch has {batch_size} samples"
                )
            emoji_embeds = self.get_emoji_embedding(emoji_tokens, pooled_output.device)
            pooled_output = pooled_output + self.emoji_projection(emoji_embeds)
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.dense(pooled_output)
        pooled_output = self.relu(pooled_output)
        logits = self.classifier(pooled_output)
        outputs = (logits,) + outputs[2:]
        if labels is not None:
            loss_fct = AsymmetricLossOptimized(gamma_neg=4, gamma_pos=0, clip=0.05)
            loss = loss_fct(logits, labels)
            outputs = (loss,) + outputs
        return outputs

class MultiLabelPipeline(Pipeline):
    """Inference pipeline that maps input text to the labels whose sigmoid score exceeds a threshold."""

    def __init__(
        self,
        model: Union[PreTrainedModel, 'TFPreTrainedModel'],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
        threshold: float = 0.3,
        emoji2vec_path: Optional[str] = None
    ):
        """Create the pipeline.

        Args:
            device: Negative values select the CPU; otherwise forwarded to the base pipeline.
            threshold: Minimum sigmoid score for a label to be reported.
            emoji2vec_path: Optional Emoji2Vec file; when it loads, emojis are extracted from inputs.
        """
        # The base class receives None instead of a negative device index.
        base_device = None if device < 0 else device
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=base_device,
            binary_output=binary_output,
            task=task
        )
        self.threshold = threshold
        self.emoji2vec_path = emoji2vec_path
        self.emoji2vec = None
        if emoji2vec_path:
            self.emoji2vec = self.load_emoji2vec(emoji2vec_path)
        use_cuda = torch.cuda.is_available() and device >= 0
        self._device = torch.device("cuda") if use_cuda else torch.device("cpu")

    def load_emoji2vec(self, emoji2vec_path: str) -> gensim.models.KeyedVectors:
        """Load Emoji2Vec vectors for the pipeline; returns None (and logs) on any failure."""
        if os.path.exists(emoji2vec_path):
            try:
                with open(emoji2vec_path, 'r', encoding='utf-8') as handle:
                    header = handle.readline()
                    if not header.strip():
                        raise ValueError(f"File Emoji2Vec {emoji2vec_path} r·ªóng")
                vectors = gensim.models.KeyedVectors.load_word2vec_format(
                    emoji2vec_path, binary=False, unicode_errors='ignore'
                )
                return vectors
            except Exception as e:
                logger.error(f"L·ªói khi t·∫£i Emoji2Vec t·ª´ {emoji2vec_path}: {str(e)}")
                return None
        logger.warning(f"Kh√¥ng t√¨m th·∫•y file Emoji2Vec t·∫°i {emoji2vec_path}")
        return None

    def _extract_emoji_tokens(self, text: str) -> List[str]:
        """Return the emojis found in ``text``, in order; empty list for non-string or empty input."""
        if isinstance(text, str) and text:
            return [match["emoji"] for match in emoji.emoji_list(text)]
        return []

    def _sanitize_parameters(self, **kwargs):
        """No call-time parameters are supported: empty kwargs for every stage."""
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        """Tokenize the input text(s) and, when Emoji2Vec is loaded, extract their emojis."""
        texts = [inputs] if isinstance(inputs, str) else inputs
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
            return_attention_mask=True,
            return_token_type_ids=True
        )
        if self.emoji2vec:
            emoji_tokens_list = [self._extract_emoji_tokens(text) or [] for text in texts]
        else:
            emoji_tokens_list = None
        on_device = {name: tensor.to(self._device) for name, tensor in encoded.items()}
        return {"inputs": on_device, "emoji_tokens": emoji_tokens_list}

    def _forward(self, model_inputs, **kwargs):
        """Run the model, retrying without ``emoji_tokens`` if the model rejects that argument."""
        tensors = model_inputs["inputs"]
        if "emoji_tokens" in model_inputs:
            emoji_tokens = model_inputs["emoji_tokens"]
        else:
            emoji_tokens = [[] for _ in range(tensors["input_ids"].size(0))]
        try:
            model_out = self.model(**tensors, emoji_tokens=emoji_tokens)
        except TypeError:
            model_out = self.model(**tensors)
        return {"logits": model_out[0]}

    def postprocess(self, model_outputs, **kwargs):
        """Convert logits into one ``{"labels": [...], "scores": [...]}`` dict per input row."""
        raw_logits = model_outputs["logits"].cpu().numpy()
        probabilities = 1 / (1 + np.exp(-raw_logits))
        results = []
        for row in probabilities:
            # Keep only the labels scoring strictly above the threshold.
            kept = [(idx, score) for idx, score in enumerate(row) if score > self.threshold]
            results.append({
                "labels": [self.model.config.id2label.get(idx, f"label_{idx}") for idx, _ in kept],
                "scores": [float(score) for _, score in kept],
            })
        return results

def CleanText(text: str) -> List[str]:
    """Normalise raw text and return its lower-cased tokens.

    Steps, in order: protect emojis behind placeholders, normalise special
    punctuation, expand contractions, split hashtags into words, drop URLs and
    @mentions, strip remaining punctuation, restore emojis, lower-case, and
    apply the misspelling dictionary.

    Args:
        text: Raw input text.

    Returns:
        List of cleaned tokens; empty list for non-string or empty input.
    """
    if not isinstance(text, str) or not text:
        return []
    logger.debug(f"VƒÉn b·∫£n g·ªëc: {text}")
    emoji_dict = {}

    def replace_emoji(emoji_char: str, data: dict) -> str:
        # Placeholders are delimited by private-use code points. The previous
        # "__EMOJI_n__" form was rewritten by punct_mapping ("_" -> "-") and then
        # removed by the punctuation strip, so emojis were never restored and
        # "emoji n" leaked into the text. These delimiters are neither
        # punctuation nor word characters, so they survive every step below.
        placeholder = f"\ue000{len(emoji_dict)}\ue001"
        emoji_dict[placeholder] = emoji_char
        return placeholder

    text = emoji.replace_emoji(text, replace_emoji)
    # Normalise punctuation first so contractions typed with curly apostrophes
    # or backticks match the straight-apostrophe keys of contraction_mapping.
    for p, replacement in punct_mapping.items():
        text = text.replace(p, replacement)
    for contraction, full_form in contraction_mapping.items():
        text = text.replace(contraction, full_form)

    def split_hashtag(match):
        # Drop the leading "#" and segment the tag into dictionary words.
        return ' '.join(wordninja.split(match.group(0)[1:]))

    text = re.sub(r"#\w+", split_hashtag, text)
    text = re.sub(r"http\S*|\S*\.com\S*|\S*www\S*", " ", text)
    text = re.sub(r"\s@\S+", " ", text)
    text = replace.sub(" ", text)
    for placeholder, emoji_text in emoji_dict.items():
        text = text.replace(placeholder, emoji_text)
    text = text.lower()
    words = [mispell_dict.get(word, word) for word in text.split()]
    text = re.sub(r"\s+", " ", ' '.join(words)).strip()
    # Split again: some dictionary replacements expand to several words.
    cleaned_words = text.split()
    logger.debug(f"VƒÉn b·∫£n ƒë√£ x·ª≠ l√Ω: {text}")
    logger.debug(f"T·ª´ ƒë√£ x·ª≠ l√Ω: {cleaned_words}")
    return cleaned_words

class InputExample:
    """Container for a single raw example (id, text pair, labels, emojis)."""

    def __init__(self, guid, text_a, text_b, label, emoji_tokens=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label
        # Use a fresh list when no emoji tokens are supplied.
        self.emoji_tokens = [] if emoji_tokens is None else emoji_tokens

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Return a deep copy of the instance attributes as a dict."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Serialize the attributes as indented, key-sorted JSON followed by a newline."""
        serialized = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return f"{serialized}\n"

class InputFeatures:
    """Container for the encoded features of a single example."""

    def __init__(self, input_ids, attention_mask, token_type_ids, label, emoji_tokens=None, guid=None):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.label = label
        # Use a fresh list when no emoji tokens are supplied.
        self.emoji_tokens = [] if emoji_tokens is None else emoji_tokens
        self.guid = guid

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Return a deep copy of the instance attributes as a dict."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Serialize the attributes as indented, key-sorted JSON followed by a newline."""
        serialized = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return f"{serialized}\n"

class GoEmotionsProcessor:
    """Reads GoEmotions-style TSV files and turns them into ``InputExample`` objects."""

    def __init__(self, args):
        # args must expose data_dir, label_file, train_file, dev_file and test_file.
        self.args = args

    def get_labels(self):
        """Read the label names, one per line, from ``<data_dir>/<label_file>``.

        Raises:
            FileNotFoundError: If the label file does not exist.
        """
        labels = []
        label_file = os.path.join(self.args.data_dir, self.args.label_file)
        if not os.path.exists(label_file):
            logger.error(f"Kh√¥ng t√¨m th·∫•y file nh√£n t·∫°i {label_file}")
            raise FileNotFoundError(f"Kh√¥ng t√¨m th·∫•y file nh√£n t·∫°i {label_file}")
        try:
            with open(label_file, "r", encoding="utf-8") as f:
                for line in f:
                    labels.append(line.rstrip())
            logger.info(f"ƒê√£ ƒë·ªçc {len(labels)} nh√£n t·ª´ {label_file}")
        except Exception as e:
            logger.error(f"L·ªói khi ƒë·ªçc file nh√£n {label_file}: {str(e)}")
            raise
        return labels

    def _read_file(self, input_file):
        """Return all lines of ``input_file``.

        Raises:
            FileNotFoundError: If the data file does not exist.
        """
        if not os.path.exists(input_file):
            logger.error(f"Kh√¥ng t√¨m th·∫•y file d·ªØ li·ªáu t·∫°i {input_file}")
            raise FileNotFoundError(f"Kh√¥ng t√¨m th·∫•y file d·ªØ li·ªáu t·∫°i {input_file}")
        try:
            with open(input_file, "r", encoding="utf-8") as f:
                lines = f.readlines()
            logger.info(f"ƒê√£ ƒë·ªçc {len(lines)} d√≤ng t·ª´ file {input_file}")
            return lines
        except Exception as e:
            logger.error(f"L·ªói khi ƒë·ªçc file {input_file}: {str(e)}")
            raise

    def _create_examples(self, lines, set_type):
        """Build ``InputExample`` objects from tab-separated ``text<TAB>labels`` lines.

        Labels are comma-separated indices; out-of-range or non-numeric entries
        are dropped, and a sample with no valid label is skipped. A line with no
        label column is assigned label 0.
        """
        examples = []
        label_list_len = len(self.get_labels())
        label_counts = Counter()
        for i, line in enumerate(lines):
            guid = f"{set_type}-{i}"
            line = line.strip()
            if not line:
                logger.warning(f"D√≤ng {i} r·ªóng. B·ªè qua.")
                continue
            items = line.split("\t")
            raw_text = items[0]
            # Keep the label column as text so the parsing below (and the warning
            # messages) work even when the column is missing. The previous
            # default `[0]` crashed on `int.strip()`.
            raw_label = items[1] if len(items) > 1 else "0"
            try:
                label = [
                    int(round(float(field.strip())))
                    for field in raw_label.split(",")
                    if field.strip().replace('.', '', 1).isdigit()
                ]
                label = [l for l in label if 0 <= l < label_list_len]
                if not label:
                    logger.warning(f"Kh√¥ng c√≥ nh√£n h·ª£p l·ªá ·ªü d√≤ng {i}: {raw_label}. B·ªè qua m·∫´u.")
                    continue
                label_counts.update(label)
            except (ValueError, IndexError) as e:
                logger.warning(f"Nh√£n kh√¥ng h·ª£p l·ªá ·ªü d√≤ng {i}: {raw_label}. B·ªè qua m·∫´u. L·ªói: {e}")
                continue
            cleaned_words = CleanText(raw_text)
            cleaned_text = ' '.join(cleaned_words)
            # Extract whole emoji sequences, as MultiLabelPipeline does at
            # inference time; per-character iteration split multi-code-point emojis.
            emoji_tokens = [item["emoji"] for item in emoji.emoji_list(raw_text)]
            examples.append(InputExample(
                guid=guid,
                text_a=cleaned_text,
                text_b=None,
                label=label,
                emoji_tokens=emoji_tokens
            ))
        logger.info(f"ƒê√£ t·∫°o {len(examples)} m·∫´u t·ª´ {set_type}")
        logger.info(f"Ph√¢n b·ªë nh√£n cho {set_type}: {dict(label_counts)}")
        if len(examples) == 0:
            logger.error(f"Kh√¥ng c√≥ m·∫´u n√†o ƒë∆∞·ª£c t·∫°o t·ª´ {set_type}. Ki·ªÉm tra file d·ªØ li·ªáu!")
        return examples

    def get_examples(self, mode):
        """Load the examples for ``mode`` ('train', 'dev' or 'test').

        Raises:
            ValueError: If ``mode`` is unknown or its file name is empty.
        """
        file_to_read = {'train': self.args.train_file, 'dev': self.args.dev_file, 'test': self.args.test_file}.get(mode)
        if not file_to_read:
            raise ValueError("Ch·∫ø ƒë·ªô ph·∫£i l√† 'train', 'dev', ho·∫∑c 'test'")
        file_path = os.path.join(self.args.data_dir, file_to_read)
        logger.info(f"ƒê·ªçc d·ªØ li·ªáu {mode} t·ª´ {file_path}")
        return self._create_examples(self._read_file(file_path), mode)

def convert_examples_to_features(args, examples, tokenizer, max_length):
    """Tokenize ``examples`` and pair each encoding with its multi-hot label vector.

    Returns:
        A list of ``InputFeatures``, or ``None`` if tokenization raised.
    """
    num_labels = len(GoEmotionsProcessor(args).get_labels())

    def convert_to_one_hot_label(label):
        """Turn a list of label indices into a multi-hot vector of length ``num_labels``."""
        vector = [0] * num_labels
        for index in label:
            if not 0 <= index < num_labels:
                logger.warning(f"Ch·ªâ s·ªë nh√£n kh√¥ng h·ª£p l·ªá {index} b·ªã b·ªè qua")
                continue
            vector[index] = 1
        return vector

    labels = []
    for example in examples:
        labels.append(convert_to_one_hot_label(example.label))

    logger.info(f"ƒêang m√£ h√≥a {len(examples)} m·∫´u v·ªõi max_length={max_length}")
    try:
        batch_encoding = tokenizer.batch_encode_plus(
            [(example.text_a, example.text_b) for example in examples],
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=True
        )
    except Exception as e:
        logger.error(f"L·ªói khi m√£ h√≥a: {str(e)}")
        return None

    features = []
    for position, (example, one_hot) in enumerate(zip(examples, labels)):
        # Pick this example's row out of every encoded field.
        encoded = {field: batch_encoding[field][position] for field in batch_encoding}
        features.append(
            InputFeatures(
                input_ids=encoded['input_ids'],
                attention_mask=encoded['attention_mask'],
                token_type_ids=encoded.get('token_type_ids'),
                label=one_hot,
                emoji_tokens=example.emoji_tokens,
                guid=example.guid
            )
        )

    # Log the first five examples for a quick sanity check.
    logger.info("*** 5 M·∫´u ƒê·∫ßu Ti√™n ***")
    for position, example in enumerate(examples[:5]):
        ids = batch_encoding['input_ids'][position]
        logger.info("*** Example ***")
        logger.info(f"guid: {example.guid}")
        logger.info(f"sentence: {example.text_a}")
        logger.info(f"tokens: {tokenizer.convert_ids_to_tokens(ids)}")
        logger.info(f"input_ids: {' '.join(map(str, ids))}")
        logger.info(f"attention_mask: {' '.join(map(str, batch_encoding['attention_mask'][position]))}")
        logger.info(f"token_type_ids: {' '.join(map(str, batch_encoding['token_type_ids'][position]))}")
        logger.info(f"label: {' '.join(map(str, labels[position]))}")
        logger.info("")

    return features

def compute_class_weights(labels):
    """Compute per-label weights ``log(N / (count + 1))`` clipped to ``[0.5, 10.0]``.

    Args:
        labels: Multi-hot label matrix with one row per sample.

    Returns:
        Float tensor with one weight per label column.
    """
    n_samples = len(labels)
    positives_per_label = np.sum(labels, axis=0)
    # The logarithm dampens the weight of extremely rare labels.
    log_weights = np.log(n_samples / (positives_per_label + 1))
    # Clipping keeps the weights in a range that is stable for training.
    clipped = np.clip(log_weights, 0.5, 10.0)
    return torch.tensor(clipped, dtype=torch.float)

def load_and_cache_examples(args, tokenizer, mode):
    """Return ``(TensorDataset, emoji_tokens)`` for ``mode``, using an on-disk feature cache.

    Features are loaded from ``<data_dir>/cached_<mode>_<task>_<max_seq_len>``
    when that file exists; otherwise they are built from the data file and
    saved there. Returns ``(None, None)`` when no examples or features could
    be produced.
    """
    cache_name = f"cached_{mode}_{args.task}_{args.max_seq_len}"
    cached_features_file = os.path.join(args.data_dir, cache_name)
    logger.info(f"Ki·ªÉm tra file cache: {cached_features_file}")

    if not os.path.exists(cached_features_file):
        logger.info(f"T·∫°o ƒë·∫∑c tr∆∞ng t·ª´ file d·ªØ li·ªáu t·∫°i {args.data_dir}")
        examples = GoEmotionsProcessor(args).get_examples(mode)
        if not examples:
            logger.error(f"Kh√¥ng c√≥ m·∫´u n√†o t·ª´ {mode}. Ki·ªÉm tra file d·ªØ li·ªáu!")
            return None, None
        features = convert_examples_to_features(args, examples, tokenizer, args.max_seq_len)
        if features is None:
            logger.error(f"Kh√¥ng th·ªÉ t·∫°o ƒë·∫∑c tr∆∞ng cho {mode}. Ki·ªÉm tra l·ªói m√£ h√≥a!")
            return None, None
        logger.info(f"L∆∞u ƒë·∫∑c tr∆∞ng v√†o {cached_features_file}")
        torch.save(features, cached_features_file)
    else:
        logger.info(f"T·∫£i ƒë·∫∑c tr∆∞ng t·ª´ file cache {cached_features_file}")
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load
        # cache files this notebook produced itself.
        features = torch.load(cached_features_file, weights_only=False)
        logger.info("*** 5 M·∫´u ƒê·∫ßu Ti√™n T·ª´ Cache ***")
        for i, cached in enumerate(features[:5]):
            logger.info("*** Example ***")
            # Older cache files may hold features without a guid attribute.
            guid = getattr(cached, 'guid', f'unknown-{i}')
            logger.info(f"guid: {guid}")
            logger.info(f"input_ids: {' '.join(map(str, cached.input_ids))}")
            logger.info(f"tokens: {tokenizer.convert_ids_to_tokens(cached.input_ids)}")
            logger.info(f"attention_mask: {' '.join(map(str, cached.attention_mask))}")
            logger.info(f"token_type_ids: {' '.join(map(str, cached.token_type_ids))}")
            logger.info(f"label: {' '.join(map(str, cached.label))}")
            logger.info(f"emoji_tokens: {cached.emoji_tokens}")
            logger.info("")

    def column(attribute, dtype):
        # Stack one attribute of every feature into a single tensor.
        return torch.tensor([getattr(feature, attribute) for feature in features], dtype=dtype)

    all_input_ids = column("input_ids", torch.long)
    all_attention_mask = column("attention_mask", torch.long)
    all_token_type_ids = column("token_type_ids", torch.long)
    all_labels = column("label", torch.float)
    # Emoji tokens are ragged string lists, so they travel beside the tensors.
    all_emoji_tokens = [feature.emoji_tokens for feature in features]
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    logger.info(f"ƒê√£ t·∫°o dataset v·ªõi {len(dataset)} m·∫´u t·ª´ {mode}")
    return dataset, all_emoji_tokens

def evaluate(args, model, eval_dataset, eval_emoji_tokens, mode, global_step=None, save_to_file=False):
    """Evaluate ``model`` on ``eval_dataset`` and return a metrics dict.

    The decision threshold is swept over ``[0.1, 0.6)`` in steps of 0.01 and
    the one with the highest macro-F1 is used for the reported metrics
    (``args.threshold`` is kept when no threshold achieves macro-F1 > 0).

    Args:
        args: Needs ``eval_batch_size``, ``device``, ``threshold`` and ``output_dir``.
        eval_emoji_tokens: Per-sample emoji lists, in the same order as ``eval_dataset``.
        mode: Name of the split, used in logs and output paths.
        global_step: Optional training step, used in logs and the output file name.
        save_to_file: When true, write the results under ``<output_dir>/<mode>/``.

    Returns:
        Dict with ``loss``, ``threshold`` and the values from ``compute_metrics``.

    Raises:
        ValueError: If the dataset yields no batches.
    """
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
    step_note = 't·∫°i b∆∞·ªõc ' + str(global_step) if global_step else ''
    logger.info(f"***** ƒê√°nh gi√° tr√™n t·∫≠p {mode} {step_note} *****")
    logger.info(f" S·ªë m·∫´u = {len(eval_dataset)}")
    logger.info(f" K√≠ch th∆∞·ªõc batch = {args.eval_batch_size}")

    model.eval()
    eval_loss = 0.0
    nb_eval_steps = 0
    # Collect per-batch arrays and concatenate once; repeated np.append copied
    # everything accumulated so far on every batch.
    score_batches = []
    label_batches = []
    for batch in tqdm(eval_dataloader, desc="ƒê√°nh gi√°"):
        batch = tuple(t.to(args.device) for t in batch)
        # The sampler is sequential, so batch k covers samples [k * bs, (k + 1) * bs).
        start = nb_eval_steps * args.eval_batch_size
        batch_emoji_tokens = eval_emoji_tokens[start:start + args.eval_batch_size]
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
            "labels": batch[3],
            "emoji_tokens": batch_emoji_tokens
        }
        with torch.no_grad():
            with torch.amp.autocast(device_type='cuda', enabled=args.device.type == "cuda"):
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]
        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        # Logits can be half precision under autocast; take the sigmoid in
        # float32 to avoid overflow and coarse rounding near the thresholds.
        score_batches.append(torch.sigmoid(logits.detach().float()).cpu().numpy())
        label_batches.append(inputs["labels"].detach().cpu().numpy())

    if nb_eval_steps == 0:
        raise ValueError(f"Cannot evaluate on '{mode}': the dataset produced no batches.")

    eval_loss = eval_loss / nb_eval_steps
    scores = np.concatenate(score_batches, axis=0)
    out_label_ids = np.concatenate(label_batches, axis=0)

    # NOTE(review): the threshold is tuned on the very split being scored, so
    # metrics for a held-out split (e.g. "test") are optimistic.
    best_threshold = args.threshold
    best_f1_macro = 0.0
    for threshold in np.arange(0.1, 0.6, 0.01):
        candidate = compute_metrics(out_label_ids, (scores > threshold).astype(np.int32))
        if candidate["f1_macro"] > best_f1_macro:
            best_f1_macro = candidate["f1_macro"]
            best_threshold = threshold

    binary_preds = (scores > best_threshold).astype(np.int32)
    results = {"loss": eval_loss, "threshold": best_threshold}
    results.update(compute_metrics(out_label_ids, binary_preds))

    logger.info(f"K·∫øt qu·∫£ ƒë√°nh gi√° tr√™n t·∫≠p {mode}:")
    for key, value in results.items():
        logger.info(f"  {key} = {value}")

    if save_to_file:
        output_dir = os.path.join(args.output_dir, mode)
        os.makedirs(output_dir, exist_ok=True)
        output_eval_file = os.path.join(output_dir, f"{mode}-{global_step}.txt" if global_step else f"{mode}.txt")
        try:
            with open(output_eval_file, "w") as f:
                for key in sorted(results.keys()):
                    f.write(f"{key} = {results[key]}\n")
                logger.info(f"ƒê√£ l∆∞u k·∫øt qu·∫£ ƒë√°nh gi√° {mode} v√†o {output_eval_file}")
        except Exception as e:
            # Saving is best-effort: log and still return the results.
            logger.error(f"L·ªói khi l∆∞u k·∫øt qu·∫£ ƒë√°nh gi√°: {str(e)}")

    logger.info(f"Ng∆∞·ª°ng t·ªët nh·∫•t: {best_threshold}")
    logger.info(f"D·ª± ƒëo√°n m·∫´u (5 m·∫´u ƒë·∫ßu): {binary_preds[:5]}")
    logger.info(f"Nh√£n th·ª±c t·∫ø (5 m·∫´u ƒë·∫ßu): {out_label_ids[:5]}")
    return results

def compute_metrics(true_labels, pred_labels):
    """Compute accuracy plus macro/micro F1, precision and recall.

    Args:
        true_labels: Ground-truth labels, passed straight to the scorers.
        pred_labels: Predicted labels, passed straight to the scorers.

    Returns:
        dict with the keys ``accuracy``, ``f1_macro``, ``f1_micro``,
        ``precision_macro``, ``precision_micro``, ``recall_macro`` and
        ``recall_micro``. If any scorer raises, the error is logged and every
        metric is reported as ``0.0`` instead of propagating the exception.
    """
    metric_names = (
        "accuracy",
        "f1_macro",
        "f1_micro",
        "precision_macro",
        "precision_micro",
        "recall_macro",
        "recall_micro",
    )
    try:
        metrics = {"accuracy": accuracy_score(true_labels, pred_labels)}
        # Same evaluation order as the key order above: f1, precision, recall,
        # each with macro first and micro second.
        for prefix, scorer in (("f1", f1_score), ("precision", precision_score), ("recall", recall_score)):
            for averaging in ("macro", "micro"):
                metrics[f"{prefix}_{averaging}"] = scorer(
                    true_labels, pred_labels, average=averaging, zero_division=0
                )
    except Exception as e:
        logger.error(f"L·ªói khi t√≠nh to√°n ch·ªâ s·ªë: {str(e)}")
        # Best-effort fallback: keep the caller running with all-zero metrics.
        metrics = {name: 0.0 for name in metric_names}
    return metrics

def train(args, model, tokenizer, train_dataset, train_emoji_tokens, dev_dataset=None, dev_emoji_tokens=None, test_dataset=None, test_emoji_tokens=None):
    """Fine-tune ``model`` on ``train_dataset`` with early stopping on dev macro-F1.

    Args:
        args: Run configuration; reads ``train_batch_size``,
            ``gradient_accumulation_steps``, ``num_train_epochs``,
            ``learning_rate``, ``device``, ``logging_steps``, ``save_steps``
            and ``output_dir``.
        model: Model called as ``model(**inputs)``; ``outputs[0]`` is used as
            the loss.
        tokenizer: Saved next to every checkpoint.
        train_dataset: Dataset exposing ``.tensors``; columns 0-3 are fed as
            ``input_ids``, ``attention_mask``, ``token_type_ids`` and
            ``labels``, and the last tensor is used to compute class weights.
        train_emoji_tokens: Per-example emoji tokens, indexed by dataset
            position (assumed to be aligned with ``train_dataset`` — verify
            against ``load_and_cache_examples``).
        dev_dataset, dev_emoji_tokens: Optional validation data; drives
            best-model tracking and early stopping.
        test_dataset, test_emoji_tokens: Optional test data, evaluated for
            reporting only.

    Returns:
        Tuple ``(global_step, average_loss)`` where ``average_loss`` is the
        accumulated loss divided by ``max(global_step, 1)``.
    """
    # Local import: `copy.deepcopy` is needed below and `copy` is not among the
    # imports visible at the top of the notebook.
    import copy

    # The loader shuffles examples, so a batch is NOT the contiguous slice
    # [step * batch_size, (step + 1) * batch_size). Carry each example's dataset
    # position as an extra tensor column so the matching emoji tokens can be
    # looked up for exactly the examples that were sampled.
    example_positions = torch.arange(len(train_dataset), dtype=torch.long)
    indexed_train_dataset = TensorDataset(*train_dataset.tensors, example_positions)
    train_sampler = RandomSampler(indexed_train_dataset)
    train_dataloader = DataLoader(indexed_train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    labels = train_dataset.tensors[-1].numpy()
    class_weights = compute_class_weights(labels).to(args.device)
    logger.info(f"Tr·ªçng s·ªë l·ªõp: {class_weights}")
    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=1e-8)
    # Linear schedule with the first 20% of optimizer steps used for warmup.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(t_total * 0.2),
        num_training_steps=t_total
    )
    # Gradient scaling is only used when running on CUDA.
    scaler = torch.amp.GradScaler('cuda') if args.device.type == "cuda" else None
    logger.info("***** B·∫Øt ƒë·∫ßu hu·∫•n luy·ªán *****")
    logger.info(f" S·ªë m·∫´u = {len(train_dataset)}")
    logger.info(f" S·ªë epoch = {args.num_train_epochs}")
    logger.info(f" K√≠ch th∆∞·ªõc batch hu·∫•n luy·ªán = {args.train_batch_size}")
    logger.info(f" S·ªë b∆∞·ªõc t√≠ch l≈©y gradient = {args.gradient_accumulation_steps}")
    logger.info(f" T·ªïng s·ªë b∆∞·ªõc t·ªëi ∆∞u h√≥a = {t_total}")
    global_step = 0
    tr_loss = 0.0
    best_f1_macro = 0.0
    best_model_state = None
    patience = 3  # number of consecutive non-improving evaluations tolerated
    patience_counter = 0
    stop_training = False

    for epoch in range(int(args.num_train_epochs)):
        if stop_training:
            logger.info(f"Early stopping: D·ª´ng t·∫°i epoch {epoch+1} do F1-macro kh√¥ng c·∫£i thi·ªán {patience} l·∫ßn li√™n ti·∫øp.")
            break
        epoch_loss = []
        for step, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")):
            model.train()
            # Last column holds the dataset positions of the sampled examples;
            # it is bookkeeping only and is not sent to the device or the model.
            batch_positions = batch[-1].tolist()
            batch = tuple(t.to(args.device) for t in batch[:-1])
            batch_emoji_tokens = [train_emoji_tokens[position] for position in batch_positions]
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "labels": batch[3],
                "emoji_tokens": batch_emoji_tokens,
                "class_weights": class_weights
            }
            with torch.amp.autocast(device_type='cuda', enabled=args.device.type == "cuda"):
                outputs = model(**inputs)
                loss = outputs[0]
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            tr_loss += loss.item()
            epoch_loss.append(loss.item())
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if scaler:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1
                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Periodic progress evaluation: dev if available, otherwise test.
                    if dev_dataset:
                        results = evaluate(args, model, dev_dataset, dev_emoji_tokens, mode="dev", global_step=global_step)
                        logger.info(f"B∆∞·ªõc {global_step} - K·∫øt qu·∫£ validation: {results}")
                    elif test_dataset:
                        results = evaluate(args, model, test_dataset, test_emoji_tokens, mode="test", global_step=global_step)
                        logger.info(f"B∆∞·ªõc {global_step} - K·∫øt qu·∫£ test: {results}")
                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    os.makedirs(output_dir, exist_ok=True)
                    model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info(f"ƒê√£ l∆∞u checkpoint m√¥ h√¨nh t·∫°i {output_dir}")
                    if dev_dataset:
                        results = evaluate(args, model, dev_dataset, dev_emoji_tokens, mode="dev", global_step=global_step, save_to_file=True)
                        if results.get("f1_macro", 0.0) > best_f1_macro:
                            # New best: snapshot the weights in memory and on disk.
                            best_f1_macro = results["f1_macro"]
                            best_model_state = copy.deepcopy(model.state_dict())
                            patience_counter = 0
                            best_output_dir = os.path.join(args.output_dir, "best_checkpoint")
                            os.makedirs(best_output_dir, exist_ok=True)
                            model.save_pretrained(best_output_dir)
                            tokenizer.save_pretrained(best_output_dir)
                            torch.save(args, os.path.join(best_output_dir, "training_args.bin"))
                            logger.info(f"ƒê√£ l∆∞u m√¥ h√¨nh t·ªët nh·∫•t t·∫°i {best_output_dir}")
                        else:
                            patience_counter += 1
                            logger.info(f"F1-macro kh√¥ng c·∫£i thi·ªán ({patience_counter}/{patience})")
                            if patience_counter >= patience:
                                # The flag is checked at the top of the next epoch;
                                # the current epoch still runs to completion.
                                stop_training = True
        # `train_history_epochs` is not defined in this function; it is expected
        # to be a module-level list created elsewhere in the notebook.
        train_history_epochs.append(mean(epoch_loss))
        logger.info(f"Epoch {epoch + 1} - Loss trung b√¨nh: {train_history_epochs[-1]:.4f}")
        # End-of-epoch early-stopping check (skipped if a mid-epoch check already stopped training)
        if not stop_training and dev_dataset:
            dev_results = evaluate(args, model, dev_dataset, dev_emoji_tokens, mode="dev", global_step=global_step, save_to_file=True)
            logger.info(f"Epoch {epoch + 1} - K·∫øt qu·∫£ validation: {dev_results}")
            f1_macro = dev_results.get("f1_macro", 0.0)
            if f1_macro > best_f1_macro:
                best_f1_macro = f1_macro
                best_model_state = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1
                logger.info(f"F1-macro kh√¥ng c·∫£i thi·ªán ({patience_counter}/{patience})")
                if patience_counter >= patience:
                    stop_training = True
        if test_dataset and not stop_training:
            test_results = evaluate(args, model, test_dataset, test_emoji_tokens, mode="test", global_step=global_step, save_to_file=True)
            logger.info(f"Epoch {epoch + 1} - K·∫øt qu·∫£ test: {test_results}")
    # After training ends, restore the weights that achieved the best dev macro-F1 (if any were recorded)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        logger.info("ƒê√£ load l·∫°i model v·ªõi F1-macro t·ªët nh·∫•t.")
    return global_step, tr_loss / max(global_step, 1)

class Args:
    """Static run configuration, read as plain attributes by the pipeline."""

    # --- Task identity and reproducibility ---
    task: str = "goemotions"
    taxonomy: str = "original"
    seed: int = 42

    # --- Input data (Google Drive mounted in Colab) ---
    data_dir: str = "/content/drive/MyDrive/Goemotions/data"
    train_file: str = "train.tsv"
    dev_file: str = "dev.tsv"
    test_file: str = "test.tsv"
    label_file: str = "labels.txt"
    config_dir: str = "/content/drive/MyDrive/Goemotions/config"

    # --- Pretrained resources ---
    model_name_or_path: str = "monologg/bert-base-cased-goemotions-original"
    tokenizer_name_or_path: str = "monologg/bert-base-cased-goemotions-original"
    emoji2vec_path: str = "/content/drive/MyDrive/Goemotions/emoji2vec.txt"

    # --- Output locations ---
    ckpt_dir: str = "/content/drive/MyDrive/Goemotions/checkpoints"
    output_dir: str = "output"  # main() joins this onto ckpt_dir

    # --- Optimisation ---
    max_seq_len: int = 128
    train_batch_size: int = 64
    eval_batch_size: int = 64
    num_train_epochs: int = 30
    learning_rate: float = 2e-5
    gradient_accumulation_steps: int = 1

    # --- Periodic evaluation / checkpointing (in optimizer steps; <= 0 disables) ---
    logging_steps: int = 500
    save_steps: int = 500

    # --- Pipeline switches ---
    do_train: bool = True
    do_eval: bool = True
    eval_all_checkpoints: bool = False

    # --- Decision threshold used as the starting value during evaluation ---
    threshold: float = 0.3

def main():
    """Run the fine-tuning / evaluation pipeline configured by ``Args``.

    Steps, in order: resolve the output directory under ``ckpt_dir``,
    initialise logging and seeding, build config/tokenizer/model, load the
    datasets selected by ``do_train`` / ``do_eval``, train, run a final dev
    evaluation, and optionally evaluate every saved ``checkpoint*`` directory
    on the test set.

    Raises:
        Exception: Any error raised by the pipeline is logged and re-raised.
    """
    args = Args()
    sys.stdout.flush()
    logger.debug("B·∫Øt ƒë·∫ßu ch∆∞∆°ng tr√¨nh ch√≠nh")
    # Outputs live under the checkpoint root; this sets an instance attribute
    # and leaves the class-level default untouched.
    args.output_dir = os.path.join(args.ckpt_dir, args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)
    init_logger(args)
    set_seed(args)
    try:
        processor = GoEmotionsProcessor(args)
        label_list = processor.get_labels()
        config = BertConfig.from_pretrained(
            args.model_name_or_path,
            num_labels=len(label_list),
            id2label={i: label for i, label in enumerate(label_list)},
            label2id={label: i for i, label in enumerate(label_list)}
        )
        tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name_or_path)
        model = BertForMultiLabelClassification.from_pretrained(
            args.model_name_or_path,
            config=config,
            emoji2vec_path=args.emoji2vec_path
        )
        args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(args.device)
        if args.do_train:
            train_dataset, train_emoji_tokens = load_and_cache_examples(args, tokenizer, mode="train")
            if train_dataset is None:
                logger.error("Kh√¥ng th·ªÉ t·∫£i d·ªØ li·ªáu train. Tho√°t!")
                return
        else:
            train_dataset, train_emoji_tokens = None, None
        if args.do_eval:
            dev_dataset, dev_emoji_tokens = load_and_cache_examples(args, tokenizer, mode="dev")
            test_dataset, test_emoji_tokens = load_and_cache_examples(args, tokenizer, mode="test")
        else:
            dev_dataset, dev_emoji_tokens = None, None
            test_dataset, test_emoji_tokens = None, None
        if args.do_train:
            global_step, tr_loss = train(args, model, tokenizer, train_dataset, train_emoji_tokens, dev_dataset, dev_emoji_tokens, test_dataset, test_emoji_tokens)
            logger.info(f"Ho√†n th√†nh hu·∫•n luy·ªán. B∆∞·ªõc to√†n c·ª•c: {global_step}, Loss trung b√¨nh: {tr_loss:.4f}")
        if args.do_eval and dev_dataset:
            results = evaluate(args, model, dev_dataset, dev_emoji_tokens, mode="dev", save_to_file=True)
            logger.info(f"K·∫øt qu·∫£ validation cu·ªëi c√πng: {results}")
        if args.eval_all_checkpoints:
            if test_dataset is None:
                # Test data is only loaded when do_eval is set (and loading
                # succeeded); without it there is nothing to evaluate against.
                logger.warning("eval_all_checkpoints is set but no test dataset is available; skipping checkpoint evaluation.")
            else:
                checkpoints = [os.path.join(args.output_dir, d) for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
                for checkpoint in checkpoints:
                    # Directory names have the form "checkpoint-<global_step>".
                    global_step = checkpoint.split("-")[-1]
                    # Reload with the same emoji2vec resource as the initial
                    # model so every checkpoint is built identically.
                    model = BertForMultiLabelClassification.from_pretrained(
                        checkpoint,
                        config=config,
                        emoji2vec_path=args.emoji2vec_path
                    )
                    model.to(args.device)
                    results = evaluate(args, model, test_dataset, test_emoji_tokens, mode="test", global_step=global_step, save_to_file=True)
                    logger.info(f"K·∫øt qu·∫£ test cho checkpoint {global_step}: {results}")
    except Exception as e:
        logger.error(f"L·ªói trong qu√° tr√¨nh th·ª±c thi ch√≠nh: {str(e)}")
        raise

# Entry point: run the full pipeline when this code is executed as __main__.
if __name__ == "__main__":
    main()

DEBUG:__main__:B·∫Øt ƒë·∫ßu ch∆∞∆°ng tr√¨nh ch√≠nh
INFO:__main__:Logger ƒë√£ ƒë∆∞·ª£c kh·ªüi t·∫°o.
INFO:__main__:ƒê√£ ƒë·ªçc 28 nh√£n t·ª´ /content/drive/MyDrive/Goemotions/data/labels.txt
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /monologg/bert-base-cased-goemotions-original/resolve/main/config.json HTTP/1.1" 307 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /api/resolve-cache/models/monologg/bert-base-cased-goemotions-original/13c44c849132f82bb61188d909a574badffb27a3/config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /monologg/bert-base-cased-goemotions-original/resolve/main/tokenizer_config.json HTTP/1.1" 307 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /api/resolve-cache/models/monologg/bert-base-cased-goemotions-original/13c44c849132f82bb61188d909a574badffb27a3/tokenizer_config.json HTTP/1.1" 200 0
INFO:gensim.models.keyedvectors:loading projection weights from /content/drive/MyDrive/G