# In [None]:
import argparse
import csv
import json
import logging
import os
import pickle
import random
import re
import string
import sys
import time
from collections import Counter

import emoji
import gensim
import joblib
import nlpaug.augmenter.word as naw
import numpy as np
import pandas as pd
import wordninja
from google.colab import userdata
from huggingface_hub import HfApi, HfFolder, upload_folder
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer

# Logging setup: INFO and above, one "<time> - <LEVEL> - <message>" line per record.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Lookup tables for text preprocessing.
# CONTRACTION_MAPPING maps English contractions, plus two abbreviations
# ("u.s", "e.g"), to their expanded forms. Keys are case-sensitive strings,
# which is why some appear both capitalised and lower-case ("I'd" / "i'd").
# NOTE: entry order is kept deliberately (e.g. "I'd" precedes "I'd've"); any
# consumer that applies the entries one after another as substring
# replacements depends on it.
CONTRACTION_MAPPING = {
    "ain't": "is not", "aren't": "are not", "can't": "cannot", "'cause": "because",
    "could've": "could have", "couldn't": "could not", "didn't": "did not",
    "doesn't": "does not", "don't": "do not", "hadn't": "had not",
    "hasn't": "has not", "haven't": "have not", "he'd": "he would",
    "he'll": "he will", "he's": "he is", "how'd": "how did",
    "how'd'y": "how do you", "how'll": "how will", "how's": "how is",
    "I'd": "I would", "I'd've": "I would have", "I'll": "I will",
    "I'll've": "I will have", "I'm": "I am", "I've": "I have",
    "i'd": "i would", "i'd've": "i would have", "i'll": "i will",
    "i'll've": "i will have", "i'm": "i am", "i've": "i have",
    "isn't": "is not", "it'd": "it would", "it'd've": "it would have",
    "it'll": "it will", "it'll've": "it will have", "it's": "it is",
    "let's": "let us", "ma'am": "madam", "mayn't": "may not",
    "might've": "might have", "mightn't": "might not",
    "mightn't've": "might not have", "must've": "must have",
    "mustn't": "must not", "mustn't've": "must not have",
    "needn't": "need not", "needn't've": "need not have",
    "o'clock": "of the clock", "oughtn't": "ought not",
    "oughtn't've": "ought not have", "shan't": "shall not",
    "sha'n't": "shall not", "shan't've": "shall not have",
    "she'd": "she would", "she'd've": "she would have",
    "she'll": "she will", "she'll've": "she will have", "she's": "she is",
    "should've": "should have", "shouldn't": "should not",
    "shouldn't've": "should not have", "so've": "so have", "so's": "so as",
    "this's": "this is", "that'd": "that would", "that'd've": "that would have",
    "that's": "that is", "there'd": "there would", "there'd've": "there would have",
    "there's": "there is", "here's": "here is", "they'd": "they would",
    "they'd've": "they would have", "they'll": "they will",
    "they'll've": "they will have", "they're": "they are",
    "they've": "they have", "to've": "to have", "wasn't": "was not",
    "we'd": "we would", "we'd've": "we would have", "we'll": "we will",
    "we'll've": "we will have", "we're": "we are", "we've": "we have",
    "weren't": "were not", "what'll": "what will", "what'll've": "what will have",
    "what're": "what are", "what's": "what is", "what've": "what have",
    "when's": "when is", "when've": "when have", "where'd": "where did",
    "where's": "where is", "where've": "where have", "who'll": "who will",
    "who'll've": "who will have", "who's": "who is", "who've": "who have",
    "why's": "why is", "why've": "why have", "will've": "will have",
    "won't": "will not", "won't've": "will not have", "would've": "would have",
    "wouldn't": "would not", "wouldn't've": "would not have", "y'all": "you all",
    "y'all'd": "you all would", "y'all'd've": "you all would have",
    "y'all're": "you all are", "y'all've": "you all have", "you'd": "you would",
    "you'd've": "you would have", "you'll": "you will", "you'll've": "you will have",
    "you're": "you are", "you've": "you have", "u.s": "america", "e.g": "for example"
}

# Typographic / non-ASCII symbols rewritten to ASCII equivalents before
# punctuation stripping (applied with str.replace in this order).
# Non-ASCII keys are spelled as \u escapes so the table survives any
# re-encoding of this file: a previous copy had its UTF-8 bytes mis-decoded
# as Mac Roman, turning every key into a multi-character sequence that never
# occurs in real text. (That copy also listed the left double quote twice.)
PUNCT_MAPPING = {
    "\u2018": "'",          # LEFT SINGLE QUOTATION MARK
    "\u20b9": "e",          # INDIAN RUPEE SIGN
    "\u00b4": "'",          # ACUTE ACCENT
    "\u00b0": "",           # DEGREE SIGN
    "\u20ac": "e",          # EURO SIGN
    "\u2122": "tm",         # TRADE MARK SIGN
    "\u221a": " sqrt ",     # SQUARE ROOT
    "\u00d7": "x",          # MULTIPLICATION SIGN
    "\u00b2": "2",          # SUPERSCRIPT TWO
    "\u2014": "-",          # EM DASH
    "\u2013": "-",          # EN DASH
    "\u2019": "'",          # RIGHT SINGLE QUOTATION MARK
    "_": "-",
    "`": "'",
    "\u201c": '"',          # LEFT DOUBLE QUOTATION MARK
    "\u201d": '"',          # RIGHT DOUBLE QUOTATION MARK
    "\u00a3": "e",          # POUND SIGN
    "\u221e": "infinity",   # INFINITY
    "\u03b8": "theta",      # GREEK SMALL LETTER THETA
    "\u00f7": "/",          # DIVISION SIGN
    "\u03b1": "alpha",      # GREEK SMALL LETTER ALPHA
    "\u2022": ".",          # BULLET
    "\u00e0": "a",          # LATIN SMALL LETTER A WITH GRAVE
    "\u2212": "-",          # MINUS SIGN
    "\u03b2": "beta",       # GREEK SMALL LETTER BETA
    "\u2205": "",           # EMPTY SET
    "\u00b3": "3",          # SUPERSCRIPT THREE
    "\u03c0": "pi",         # GREEK SMALL LETTER PI
}

# Misspellings, British spellings and run-together words -> replacement.
# Lookups happen one token at a time on lower-cased, whitespace-split text,
# so every key must be a single lower-case token. Keys that previously had
# capitals ('Qoura', 'Whta', 'doI', 'theBest', 'Etherium') or a trailing
# space ('youtu ') could never match and are normalised here; values are
# lower-case to match the surrounding cleaned text.
MISPELL_DICT = {
    'colour': 'color', 'centre': 'center', 'favourite': 'favorite', 'travelling': 'traveling',
    'counselling': 'counseling', 'theatre': 'theater', 'cancelled': 'canceled', 'labour': 'labor',
    'organisation': 'organization', 'wwii': 'world war 2', 'citicise': 'criticize', 'youtu': 'youtube',
    'qoura': 'quora', 'sallary': 'salary', 'whta': 'what', 'narcisist': 'narcissist', 'howdo': 'how do',
    'whatare': 'what are', 'howcan': 'how can', 'howmuch': 'how much', 'howmany': 'how many',
    'whydo': 'why do', 'doi': 'do i', 'thebest': 'the best', 'howdoes': 'how does',
    'mastrubation': 'masturbation', 'mastrubate': 'masturbate', 'mastrubating': 'masturbating',
    'pennis': 'penis', 'etherium': 'ethereum', 'narcissit': 'narcissist', 'bigdata': 'big data',
    '2k17': '2017', '2k18': '2018', 'qouta': 'quota', 'exboyfriend': 'ex boyfriend',
    'airhostess': 'air hostess', 'whst': 'what', 'watsapp': 'whatsapp',
    'demonitisation': 'demonetization', 'demonitization': 'demonetization',
    'demonetisation': 'demonetization', 'pissed': 'pissed',  # identity entry; no effect
}

# Characters that REPLACE_PUNCT turns into spaces: ASCII punctuation plus
# common typographic quotes, dashes and the ellipsis. "#", "!" and "?" are
# deliberately kept. Every entry must be a single character because the
# entries are joined into one regex character class; non-ASCII ones are
# written as \u escapes so they survive re-encoding of this file (a
# mis-decoded copy had turned them into multi-character sequences, which put
# letters such as o-circumflex into the class and stripped them from text).
_EXTRA_PUNCT = {
    "\u2018", "\u2019",  # left / right single quotation mark
    "\u201c", "\u201d",  # left / right double quotation mark
    "\u2013", "\u2014",  # en dash / em dash
    "\u2026",            # horizontal ellipsis
}
_KEPT_PUNCT = {"#", "!", "?"}
PUNCT_CHARS = sorted((set(string.punctuation) | _EXTRA_PUNCT) - _KEPT_PUNCT)
PUNCTUATION = "".join(PUNCT_CHARS)
REPLACE_PUNCT = re.compile("[%s]" % re.escape(PUNCTUATION))

def load_embeddings(glove_path: str, emoji2vec_path: str, embedding_dim: int, logger: logging.Logger) -> tuple:
    """Load GloVe word vectors and emoji2vec emoji vectors.

    Args:
        glove_path: GloVe text file; each line is a token followed by
            ``embedding_dim`` whitespace-separated floats.
        emoji2vec_path: emoji2vec vectors in (non-binary) word2vec text format.
        embedding_dim: Number of floats per GloVe vector. Lines with fewer
            than ``embedding_dim + 1`` fields are skipped.
        logger: Receives progress and error messages.

    Returns:
        ``(glove_embeddings, emoji2vec)``: a dict mapping each token to a
        float32 array of length ``embedding_dim``, and a gensim
        ``KeyedVectors`` loaded with ``load_word2vec_format``.

    Raises:
        Any exception raised while reading either file; it is logged first.
    """
    logger.info(f"Loading GloVe from {glove_path}")
    glove_embeddings = {}
    try:
        with open(glove_path, 'r', encoding='utf-8') as f:
            for line in f:
                values = line.strip().split()
                if len(values) < embedding_dim + 1:
                    continue  # blank or truncated line
                # The vector is always the LAST embedding_dim fields. Splitting
                # from the end keeps the parse correct when a token itself
                # contains spaces; taking values[1:] would then mis-assign the
                # word and produce a vector of the wrong length.
                split_at = len(values) - embedding_dim
                word = " ".join(values[:split_at])
                glove_embeddings[word] = np.asarray(values[split_at:], dtype='float32')
        logger.info(f"Loaded {len(glove_embeddings)} GloVe vectors")
    except Exception as e:
        logger.error(f"Error loading GloVe from {glove_path}: {str(e)}")
        raise

    logger.info(f"Loading emoji2vec from {emoji2vec_path}")
    try:
        # unicode_errors='ignore' drops undecodable bytes instead of failing.
        emoji2vec = gensim.models.KeyedVectors.load_word2vec_format(
            emoji2vec_path, binary=False, unicode_errors='ignore'
        )
        logger.info(f"Loaded {len(emoji2vec.key_to_index)} emoji vectors")
    except Exception as e:
        logger.error(f"Error loading emoji2vec from {emoji2vec_path}: {str(e)}")
        raise
    return glove_embeddings, emoji2vec

# Contraction table folded to lower case (clean_text lower-cases before
# expanding), plus one regex that matches any contraction only as a whole
# token: not preceded or followed by a word character. Longer keys are tried
# first so e.g. "i'd've" wins over "i'd".
_CONTRACTIONS = {k.lower(): v.lower() for k, v in CONTRACTION_MAPPING.items()}
_CONTRACTION_RE = re.compile(
    r"(?<!\w)(?:"
    + "|".join(re.escape(k) for k in sorted(_CONTRACTIONS, key=len, reverse=True))
    + r")(?!\w)"
)
# URLs and @mentions; "(?<!\S)" also catches a mention at the very start.
_URL_OR_MENTION_RE = re.compile(r"http\S*|\S*\.com\S*|\S*www\S*|(?<!\S)@\S+")
_HASHTAG_RE = re.compile(r"#\w+")


def clean_text(text: str, logger: logging.Logger) -> str:
    """Normalise a raw comment into lower-case, single-space-separated tokens.

    Steps, in order:
      1. replace symbols listed in PUNCT_MAPPING (e.g. curly quotes -> ');
      2. lower-case;
      3. remove URLs (``http...``, ``*.com*``, ``*www*``) and @mentions;
      4. expand contractions from CONTRACTION_MAPPING (whole tokens only);
      5. replace each hashtag with the words ``wordninja.split`` finds in it;
      6. replace remaining punctuation (REPLACE_PUNCT) with spaces;
      7. map each token through MISPELL_DICT.

    Why this order: symbols are normalised before contraction expansion so a
    contraction typed with a curly apostrophe (U+2019) is expanded; text is
    lower-cased first so sentence-initial "Don't" is expanded; whole-token
    matching stops the "u.s" entry from rewriting the middle of "you.so";
    URLs are removed before hashtags are split so URL fragments are not
    split into words.

    Args:
        text: Raw input; non-strings and blank strings yield "".
        logger: Unused; kept for interface compatibility.

    Returns:
        The cleaned text (possibly "").
    """
    if not isinstance(text, str) or not text.strip():
        return ""

    for symbol, replacement in PUNCT_MAPPING.items():
        text = text.replace(symbol, replacement)
    text = text.lower()
    text = _URL_OR_MENTION_RE.sub(" ", text)
    text = _CONTRACTION_RE.sub(lambda m: _CONTRACTIONS[m.group(0)], text)
    text = _HASHTAG_RE.sub(lambda m: " ".join(wordninja.split(m.group(0)[1:])), text)
    text = REPLACE_PUNCT.sub(" ", text)

    words = [MISPELL_DICT.get(word, word) for word in text.split()]
    # Replacement values may themselves contain spaces; normalise once more.
    return re.sub(r"\s+", " ", " ".join(words)).strip()

def create_features(texts: list, tokenizer: Tokenizer, glove_embeddings: dict, emoji2vec: gensim.models.KeyedVectors,
                   embedding_dim: int, max_len: int, logger: logging.Logger) -> np.ndarray:
    """Embed each text as the mean of its known token vectors.

    Texts are tokenised with ``tokenizer`` and cut to their first ``max_len``
    tokens. Padding index 0 is ignored. Each remaining token is looked up in
    ``emoji2vec`` when ``emoji.is_emoji`` says it is an emoji and emoji2vec
    has it; otherwise in ``glove_embeddings``. Tokens found in neither are
    skipped.

    Returns:
        float32 array of shape ``(len(texts), embedding_dim)``. A row is all
        zeros when none of the text's tokens had a vector.
    """
    def vector_for(token):
        """Return the embedding for *token*, or None if it has none."""
        if emoji.is_emoji(token) and token in emoji2vec:
            return emoji2vec[token]
        if token in glove_embeddings:
            return glove_embeddings[token]
        return None

    token_ids = pad_sequences(
        tokenizer.texts_to_sequences(texts),
        maxlen=max_len,
        padding="post",
        truncating="post",
    )
    features = np.zeros((len(texts), embedding_dim), dtype=np.float32)

    for row, ids in enumerate(token_ids):
        candidates = (vector_for(tokenizer.index_word.get(tok, "<OOV>")) for tok in ids if tok != 0)
        found = [vec for vec in candidates if vec is not None]
        if found:
            features[row] = np.mean(found, axis=0)

    logger.info(f"Created features for {len(texts)} samples, shape: {features.shape}")
    return features

class GoemotionsProcessor:
    """Load GoEmotions-format TSV splits and turn them into cleaned examples.

    Data files are headerless and tab-separated with three columns: raw text,
    comma-separated label indices, and an id. ``args`` must provide
    ``data_dir``, ``label_file``, ``train_file``, ``dev_file`` and
    ``test_file``.
    """

    def __init__(self, args, logger):
        self.args = args
        self.logger = logger

    def get_labels(self) -> list:
        """Return label names, one per non-blank line of ``args.label_file``.

        Falls back to the 28 default GoEmotions labels when the file is
        missing.

        Raises:
            ValueError: if the label file exists but has no labels.
        """
        label_file = os.path.join(self.args.data_dir, self.args.label_file)
        self.logger.info(f"Reading labels from {label_file}")
        if not os.path.exists(label_file):
            self.logger.warning(f"Label file {label_file} not found. Using default GoEmotions labels.")
            labels = [
                'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion',
                'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment',
                'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism',
                'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral'
            ]
            return labels
        with open(label_file, "r", encoding="utf-8") as f:
            labels = [line.strip() for line in f if line.strip()]
        if not labels:
            self.logger.error(f"Label file {label_file} is empty")
            raise ValueError(f"Label file {label_file} is empty")
        self.logger.info(f"Loaded {len(labels)} labels")
        return labels

    def _read_file(self, input_file: str) -> pd.DataFrame:
        """Read a headerless ``text<TAB>labels<TAB>id`` file.

        Quote characters are taken literally (``csv.QUOTE_NONE``): raw
        comments may start with or contain ``"``, which pandas' default
        quoting would treat as a quoted field, possibly spanning lines. All
        columns are read as strings without NA inference, so a comment like
        "NA" is kept verbatim and label strings are never turned into floats
        (e.g. "3" -> 3.0, which the label parser would reject).

        Raises:
            FileNotFoundError: if ``input_file`` does not exist.
            ValueError: if the file has no rows.
        """
        if not os.path.exists(input_file):
            self.logger.error(f"Data file {input_file} not found")
            raise FileNotFoundError(f"Data file {input_file} not found")
        try:
            df = pd.read_csv(
                input_file, sep='\t', header=None, names=['text', 'labels', 'id'],
                dtype=str, quoting=csv.QUOTE_NONE, keep_default_na=False,
            )
        except Exception as e:
            self.logger.error(f"Error reading file {input_file}: {str(e)}")
            raise
        if df.empty:
            self.logger.error(f"Data file {input_file} is empty")
            raise ValueError(f"Data file {input_file} is empty")
        self.logger.info(f"Read {len(df)} lines from {input_file}")
        return df

    def _augment_data(self, texts: list, labels: list) -> tuple:
        """Return the inputs plus one synonym-augmented copy of each text.

        ``augment`` may return a list or a plain string depending on the
        nlpaug version; indexing ``[0]`` on a string would keep only its
        first character, so both shapes are handled. When augmentation yields
        nothing usable, no copy is added for that text.
        """
        aug = naw.SynonymAug(aug_p=0.3)
        augmented_texts, augmented_labels = [], []
        for text, label in zip(texts, labels):
            augmented_texts.append(text)
            augmented_labels.append(label)
            result = aug.augment(text)
            aug_text = result[0] if isinstance(result, list) and result else result
            if isinstance(aug_text, str) and aug_text.strip():
                augmented_texts.append(aug_text)
                augmented_labels.append(label)
        self.logger.info(f"Augmented to {len(augmented_texts)} samples")
        return augmented_texts, augmented_labels

    def _balance_labels(self, examples: list, label_list_len: int, set_type: str) -> list:
        """Resample the training split so per-label counts move towards a target.

        target = max(1, min(int(6 * median non-empty pool size), len(examples) // 2)).
        The examples carrying each label form that label's pool:
          * a pool smaller than the target keeps ALL its examples and is
            topped up with random duplicates until it reaches the target;
          * a larger pool is randomly down-sampled to
            max(target, 50% of the pool).
        An example with several labels is drawn once per pool it belongs to,
        so multi-label examples can appear more than once. Splits other than
        "train" are returned unchanged (same object).
        """
        if set_type != "train":
            return examples
        self.logger.info("Balancing labels for training set")

        pools = {label: [] for label in range(label_list_len)}
        for ex in examples:
            for label in set(ex['labels']):
                if label in pools:
                    pools[label].append(ex)

        counts = [len(pool) for pool in pools.values() if pool]
        if not counts:
            return examples
        # Floor of 1 so a very small split is not balanced down to nothing.
        target_count = max(1, min(int(np.median(counts) * 6.0), len(examples) // 2))

        balanced_examples = []
        for label in range(label_list_len):
            pool = pools[label]
            if not pool:
                continue
            if len(pool) < target_count:
                # Keep every original example, then add random duplicates.
                balanced_examples.extend(pool)
                balanced_examples.extend(random.choices(pool, k=target_count - len(pool)))
            else:
                samples_to_keep = max(target_count, int(len(pool) * 0.5))
                balanced_examples.extend(random.sample(pool, min(samples_to_keep, len(pool))))
        random.shuffle(balanced_examples)
        self.logger.info(f"Balanced from {len(examples)} to {len(balanced_examples)} samples")
        return balanced_examples

    def get_examples(self, mode: str) -> list:
        """Read, clean and (for 'train') balance the split named by *mode*.

        Raises:
            ValueError: if *mode* is not 'train', 'dev' or 'test'.
        """
        file_map = {'train': self.args.train_file, 'dev': self.args.dev_file, 'test': self.args.test_file}
        file_to_read = file_map.get(mode)
        if not file_to_read:
            raise ValueError("Mode must be 'train', 'dev', or 'test'")
        file_path = os.path.join(self.args.data_dir, file_to_read)
        df = self._read_file(file_path)
        return self._create_examples(df, mode)

    def _create_examples(self, df: pd.DataFrame, set_type: str) -> list:
        """Convert a split DataFrame into example dicts, then balance.

        Each example is ``{'guid': f"{set_type}-{index}", 'text': cleaned
        text, 'labels': [label indices]}``. Rows whose labels are missing,
        non-numeric or out of range, or whose text is empty after
        ``clean_text``, are skipped with a warning.

        Raises:
            ValueError: if no row produces a valid example.
        """
        examples = []
        label_list_len = len(self.get_labels())
        label_counts = Counter()
        for i, row in df.iterrows():
            guid = f"{set_type}-{i}"
            raw_text = row['text']
            label_str = str(row['labels'])
            try:
                label = [int(l) for l in label_str.split(',') if l.strip().isdigit()]
                label = [l for l in label if 0 <= l < label_list_len]
                if not label:
                    self.logger.warning(f"No valid labels at line {i} (text: {raw_text}, labels: {label_str}). Skipping.")
                    continue
                label_counts.update(label)
            except (ValueError, IndexError) as e:
                self.logger.warning(f"Invalid labels at line {i} (text: {raw_text}, labels: {label_str}). Error: {e}. Skipping.")
                continue
            cleaned_text = clean_text(raw_text, self.logger)
            if not cleaned_text:
                self.logger.warning(f"Empty text after cleaning at line {i} (original: {raw_text}). Skipping.")
                continue
            examples.append({'guid': guid, 'text': cleaned_text, 'labels': label})

        if not examples:
            self.logger.error(f"No valid examples created from {set_type}")
            raise ValueError(f"No valid examples created from {set_type}")

        self.logger.info(f"Created {len(examples)} examples from {set_type}")
        self.logger.info(f"Label distribution for {set_type} (before balancing): {dict(label_counts)}")
        return self._balance_labels(examples, label_list_len, set_type)

def load_data(args, logger: logging.Logger) -> tuple:
    """Load the train/dev/test splits, caching each split as a pickle.

    The cache file for a split lives in ``args.data_dir`` and its name
    includes the split's source file stem and, for train,
    ``args.max_train_samples``. Changing ``--train_file`` (e.g. ``train0.tsv``
    -> ``train1.tsv``) or the sample cap therefore builds a fresh cache
    instead of silently reusing one built from different inputs. Caches are
    NOT invalidated when preprocessing code changes; delete the
    ``cached_*_data.pkl`` files after editing cleaning, balancing or
    augmentation.

    Train data is optionally subsampled to ``args.max_train_samples`` and then
    augmented; dev/test are used as produced by ``GoemotionsProcessor``.

    Returns:
        ``(train_texts, val_texts, test_texts, train_labels, val_labels,
        test_labels, label_list)``.
    """
    processor = GoemotionsProcessor(args, logger)
    label_list = processor.get_labels()
    source_files = {'train': args.train_file, 'dev': args.dev_file, 'test': args.test_file}

    def cache_path(mode: str) -> str:
        """Cache file path for *mode*, keyed on the inputs that shape its content."""
        stem = os.path.splitext(os.path.basename(source_files[mode] or ""))[0]
        cap = f"_max{args.max_train_samples}" if mode == 'train' and args.max_train_samples else ""
        return os.path.join(args.data_dir, f"cached_{mode}_{stem}{cap}_data.pkl")

    def load_and_cache(mode: str) -> tuple:
        """Return (texts, labels) for *mode* from cache, building it if absent."""
        cached_file = cache_path(mode)
        if os.path.exists(cached_file):
            logger.info(f"Loading cached {mode} data from {cached_file}")
            # pickle.load can execute arbitrary code: only load files this script wrote.
            with open(cached_file, 'rb') as f:
                data = pickle.load(f)
            return data['texts'], data['labels']
        examples = processor.get_examples(mode)
        texts = [ex['text'] for ex in examples]
        labels = [ex['labels'] for ex in examples]
        if mode == 'train' and args.max_train_samples:
            indices = random.sample(range(len(texts)), min(args.max_train_samples, len(texts)))
            texts = [texts[i] for i in indices]
            labels = [labels[i] for i in indices]
        if mode == 'train':
            texts, labels = processor._augment_data(texts, labels)
        with open(cached_file, 'wb') as f:
            pickle.dump({'texts': texts, 'labels': labels}, f)
        logger.info(f"Cached {mode} data to {cached_file}")
        return texts, labels

    train_texts, train_labels = load_and_cache('train')
    val_texts, val_labels = load_and_cache('dev')
    test_texts, test_labels = load_and_cache('test')
    return train_texts, val_texts, test_texts, train_labels, val_labels, test_labels, label_list

def to_multi_hot(label_lists: list, num_labels: int, logger: logging.Logger) -> np.ndarray:
    """Encode per-sample label-index lists as a 0/1 int32 matrix.

    Row ``r`` has a 1 in column ``c`` when ``int(c_raw) == c`` for some entry
    of ``label_lists[r]``. Entries that cannot be converted with ``int`` or
    fall outside ``[0, num_labels)`` are skipped with a warning. A sample left
    with no valid labels keeps an all-zero row, also with a warning.

    Returns:
        Array of shape ``(len(label_lists), num_labels)``, dtype int32.
    """
    multi_hot = np.zeros((len(label_lists), num_labels), dtype=np.int32)
    for row, raw_labels in enumerate(label_lists):
        kept = []
        for raw in raw_labels:
            try:
                idx = int(raw)
            except (ValueError, TypeError):
                logger.warning(f"Non-integer label {raw} at sample {row}. Skipping.")
                continue
            if not 0 <= idx < num_labels:
                logger.warning(f"Invalid label index {idx} at sample {row} (out of range [0, {num_labels-1}]). Skipping.")
                continue
            kept.append(idx)
        if kept:
            multi_hot[row, kept] = 1
        else:
            logger.warning(f"No valid labels for sample {row}. Using empty label set.")
    return multi_hot

def compute_class_weights(labels: list, num_labels: int, logger: logging.Logger) -> dict:
    """Compute per-label weights that favour rare labels.

    With N = len(labels) samples and c_i occurrences of label i (indices
    outside ``[0, num_labels)`` are ignored):
      * raw weight = log(N / c_i) if c_i > 0, else 1.0;
      * doubled when c_i is below the median count;
      * floored at 1.0;
    then all weights are rescaled so they sum to ``num_labels``.

    Counts were previously clamped to 1e-8 before the ``c_i > 0`` test, which
    made the 1.0 fallback unreachable and gave labels that never occur the
    largest weight of all (log(N / 1e-8)).

    Args:
        labels: Per-sample lists of integer label indices.
        num_labels: Number of labels.
        logger: Unused; kept for interface compatibility.

    Returns:
        Dict mapping label index -> weight (empty when ``num_labels <= 0``).
    """
    if num_labels <= 0:
        return {}
    label_counts = np.zeros(num_labels)
    for labs in labels:
        for l in labs:
            if 0 <= l < num_labels:
                label_counts[l] += 1
    total_samples = len(labels)
    median_count = np.median(label_counts)

    class_weights = {}
    for i, count in enumerate(label_counts):
        weight = np.log(total_samples / count) if count > 0 else 1.0
        if count < median_count:
            weight *= 2.0
        class_weights[i] = max(weight, 1.0)

    weight_sum = sum(class_weights.values())
    if weight_sum > 0:
        scale_factor = num_labels / weight_sum
        class_weights = {i: w * scale_factor for i, w in class_weights.items()}
    return class_weights

def predict_examples(model, tokenizer: Tokenizer, examples: list, label_list: list, glove_embeddings: dict,
                    emoji2vec: gensim.models.KeyedVectors, embedding_dim: int, max_len: int, logger: logging.Logger,
                    threshold_percentile: float = 75.0) -> list:
    """Classify raw texts with a trained model and print a short report for each.

    Each text is cleaned with ``clean_text``, embedded with
    ``create_features`` and scored with ``model.decision_function``. The
    scores are classifier decision values (for LinearSVC: signed margins),
    not probabilities. A label is predicted when its score is at or above
    that sample's own ``threshold_percentile``-th percentile across labels.

    Args:
        model: Fitted estimator exposing ``decision_function``.
        tokenizer, glove_embeddings, emoji2vec, embedding_dim, max_len:
            Passed through to ``create_features``.
        examples: Raw input texts.
        label_list: Label names, in the model's column order.
        logger: Receives progress messages.
        threshold_percentile: Per-sample percentile used as the decision
            threshold; defaults to the previously hard-coded 75.

    Returns:
        One dict per text with keys ``text``, ``labels`` (list of label
        names, or the string "No label"), ``top_labels`` (three
        ``(label, score)`` pairs, highest first) and ``probs`` (all scores;
        the key name is kept for compatibility). Empty list for no input.
    """
    if not examples:
        logger.warning("No test examples provided")
        return []

    cleaned_examples = [clean_text(ex, logger) for ex in examples]
    logger.info(f"Preprocessed texts: {cleaned_examples}")
    features = create_features(cleaned_examples, tokenizer, glove_embeddings, emoji2vec, embedding_dim, max_len, logger)
    # Reshape so a single-label model's 1-D output is handled as one column.
    scores = np.asarray(model.decision_function(features)).reshape(len(examples), -1)

    thresholds = np.percentile(scores, threshold_percentile, axis=1, keepdims=True)
    predictions = (scores >= thresholds).astype(int)

    results = []
    for i, example in enumerate(examples):
        row = scores[i]
        predicted_labels = [label_list[j] for j in range(len(label_list)) if predictions[i][j] == 1]
        top_indices = np.argsort(row)[-3:][::-1]
        top_pairs = list(zip([label_list[j] for j in top_indices], [row[j] for j in top_indices]))
        print(f"\nExample {i+1}: {example}")
        print(f"Predicted (dynamic threshold): {predicted_labels if predicted_labels else 'No label'}")
        print(f"Top-3 labels: {top_pairs}")
        results.append({
            "text": example,
            "labels": predicted_labels if predicted_labels else "No label",
            "top_labels": top_pairs,
            "probs": row.tolist()
        })
    return results

def _parse_args() -> argparse.Namespace:
    """Parse command-line options for the training script.

    ``parse_known_args`` ignores arguments the parser does not know, such as
    the ``-f <kernel>.json`` pair a Jupyter/Colab kernel injects, so
    ``sys.argv`` no longer needs hand filtering. The old filter also dropped
    any value ending in ``.json``, even for known options.
    """
    parser = argparse.ArgumentParser(description="SVM Multi-Label Classification for GoEmotions or VAAFI")
    parser.add_argument("--data_dir", default="/content/drive/MyDrive/Goemotions/data", type=str)
    parser.add_argument("--train_file", default="train0.tsv", type=str)
    parser.add_argument("--dev_file", default="dev0.tsv", type=str)
    parser.add_argument("--test_file", default="test0.tsv", type=str)
    parser.add_argument("--label_file", default="labels.txt", type=str)
    parser.add_argument("--glove_path", default="/content/drive/MyDrive/Goemotions/glove.6B.300d.txt", type=str)
    parser.add_argument("--emoji2vec_path", default="/content/drive/MyDrive/Goemotions/emoji2vec.txt", type=str)
    parser.add_argument("--ckpt_dir", default="/content/drive/MyDrive/Goemotions/checkpoints", type=str)
    parser.add_argument("--model_type", default="goemotions-svm", type=str)
    parser.add_argument("--hf_repo_id", default="Songnguyen263/{model_type}", type=str)
    parser.add_argument("--hf_token", default=None, type=str, help="Hugging Face API token")
    parser.add_argument("--max_train_samples", type=int, default=None, help="Maximum number of training samples to use")
    parser.add_argument("--max_iter", type=int, default=500, help="Maximum iterations for SVM training")
    parser.add_argument("--output_dir", default="/content/drive/MyDrive/Goemotions/", type=str,
                        help="Directory whose SVM-results.txt the metrics are appended to")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for resampling and the SVM")
    args, unknown = parser.parse_known_args()
    if unknown:
        logger.info(f"Ignoring unrecognised command-line arguments: {unknown}")
    args.hf_repo_id = args.hf_repo_id.format(model_type=args.model_type)
    return args


def _save_artifacts(model, tokenizer, label_list: list, config: dict, local_model_dir: str, model_type: str) -> None:
    """Write model, tokenizer, label list and config into *local_model_dir*."""
    os.makedirs(local_model_dir, exist_ok=True)
    joblib.dump(model, os.path.join(local_model_dir, f"{model_type}_model.joblib"))
    with open(os.path.join(local_model_dir, f"{model_type}_tokenizer.pkl"), "wb") as f:
        pickle.dump(tokenizer, f)
    with open(os.path.join(local_model_dir, f"{model_type}_labels.txt"), "w", encoding="utf-8") as f:
        f.writelines(f"{label}\n" for label in label_list)
    with open(os.path.join(local_model_dir, f"{model_type}_config.json"), "w") as f:
        json.dump(config, f, indent=4)


def _upload_to_hub(local_model_dir: str, args: argparse.Namespace) -> None:
    """Upload *local_model_dir* to ``args.hf_repo_id`` if a token was given.

    The token is passed to ``HfApi`` directly instead of being written to
    disk with the deprecated ``HfFolder.save_token``. The repo is created
    first (no-op if it exists) because uploading to a missing repo fails.
    """
    if not args.hf_token:
        logger.warning("No Hugging Face token provided. Skipping upload.")
        return
    api = HfApi(token=args.hf_token)
    api.create_repo(repo_id=args.hf_repo_id, repo_type="model", exist_ok=True)
    api.upload_folder(
        folder_path=local_model_dir,
        repo_id=args.hf_repo_id,
        repo_type="model",
        commit_message=f"Upload {args.model_type} model and tokenizer"
    )
    logger.info(f"Uploaded model to Hugging Face: {args.hf_repo_id}")


def main():
    """Train, evaluate, save and optionally publish the SVM emotion classifier.

    Steps:
      1. Parse options and seed the RNGs, then mount Google Drive.
      2. Load the cached splits and fit a Keras tokenizer.
      3. Build mean-pooled GloVe/emoji2vec features.
      4. Train a one-vs-rest LinearSVC.
      5. Report test metrics, also appended to ``<output_dir>/SVM-results.txt``.
      6. Run a few demo predictions.
      7. Save the artifacts under ``<ckpt_dir>/hf_<model_type>``.
      8. Upload them to the Hugging Face Hub when a token is given.
    """
    args = _parse_args()
    # Python's random drives label balancing and train subsampling; numpy is
    # seeded too for anything that draws from its global RNG.
    random.seed(args.seed)
    np.random.seed(args.seed)

    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    logger.info("Google Drive mounted successfully")

    os.makedirs(args.ckpt_dir, exist_ok=True)

    train_texts, val_texts, test_texts, train_labels, val_labels, test_labels, label_list = load_data(args, logger)
    num_labels = len(label_list)
    logger.info(f"Number of labels: {num_labels}")
    logger.info(f"Training samples: {len(train_texts)}, Validation samples: {len(val_texts)}, Test samples: {len(test_texts)}")

    # Validate labels
    for split, labels in [("train", train_labels), ("val", val_labels), ("test", test_labels)]:
        invalid_labels = [labs for labs in labels if any(not isinstance(lab, int) or lab < 0 or lab >= num_labels for lab in labs)]
        if invalid_labels:
            logger.error(f"Found {len(invalid_labels)} invalid label sets in {split} split: {invalid_labels[:5]}")
            raise ValueError(f"Invalid labels detected in {split} split")

    y_train = to_multi_hot(train_labels, num_labels, logger)
    y_val = to_multi_hot(val_labels, num_labels, logger)
    y_test = to_multi_hot(test_labels, num_labels, logger)

    vocab_size = 30000
    max_len = 100
    embedding_dim = 300
    tokenizer = Tokenizer(num_words=vocab_size, oov_token="<OOV>", filters='!"$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n')
    tokenizer.fit_on_texts(train_texts)

    with open(os.path.join(args.ckpt_dir, f"{args.model_type}_tokenizer.pkl"), "wb") as f:
        pickle.dump(tokenizer, f)

    glove_embeddings, emoji2vec = load_embeddings(args.glove_path, args.emoji2vec_path, embedding_dim, logger)
    X_train = create_features(train_texts, tokenizer, glove_embeddings, emoji2vec, embedding_dim, max_len, logger)
    X_val = create_features(val_texts, tokenizer, glove_embeddings, emoji2vec, embedding_dim, max_len, logger)
    X_test = create_features(test_texts, tokenizer, glove_embeddings, emoji2vec, embedding_dim, max_len, logger)

    # OneVsRestClassifier fits one binary LinearSVC per label, and each of
    # them only ever sees classes {0, 1}. A per-label weight dict therefore
    # cannot be passed here: every estimator would weight its negatives by
    # label 0's weight and its positives by label 1's weight (older sklearn
    # raises instead). "balanced" re-weights each label's binary problem from
    # that label's own positive/negative frequencies.
    model = OneVsRestClassifier(LinearSVC(C=1.0, class_weight="balanced", max_iter=args.max_iter,
                                          tol=1e-3, random_state=args.seed))

    logger.info("Starting SVM training...")
    start_time = time.time()
    try:
        model.fit(X_train, y_train)
        logger.info(f"Training completed in {time.time() - start_time:.2f} seconds")
    except KeyboardInterrupt:
        # NOTE(review): fit has not returned, so this is most likely an
        # unfitted estimator; it is saved only as a record of the setup.
        logger.warning("Training interrupted. Saving partial model...")
        joblib.dump(model, os.path.join(args.ckpt_dir, f"{args.model_type}_partial_model.joblib"))
        logger.info(f"Partial model saved to {args.ckpt_dir}/{args.model_type}_partial_model.joblib")
        raise

    # Same rule predict_examples applies by default: per sample, predict every
    # label whose score reaches that sample's 75th percentile across labels.
    threshold_percentile = 75
    scores = model.decision_function(X_test)
    thresholds = np.percentile(scores, threshold_percentile, axis=1, keepdims=True)
    y_pred = (scores >= thresholds).astype(int)

    raw_acc = accuracy_score(y_test.flatten(), y_pred.flatten())  # per-cell (Hamming) accuracy
    micro_f1 = f1_score(y_test, y_pred, average='micro', zero_division=0)
    macro_f1 = f1_score(y_test, y_pred, average='macro', zero_division=0)
    weighted_f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)
    report = classification_report(y_test, y_pred, target_names=label_list, zero_division=0, digits=4)

    logger.info(f"Raw Accuracy: {raw_acc:.4f}")
    logger.info(f"Micro F1: {micro_f1:.4f}")
    logger.info(f"Macro F1: {macro_f1:.4f}")
    logger.info(f"Weighted F1: {weighted_f1:.4f}")
    logger.info(f"\nClassification Report:\n{report}")

    # Demo inputs. Emojis and curly apostrophes are written as escapes so the
    # strings survive any re-encoding of this file.
    test_examples = [
        # party popper, smiling face with smiling eyes
        "Feeling on top of the world today! \U0001F389\U0001F60A #BestDayEver #SoHappy",
        # crying face, broken heart
        "Totally let down... \U0001F622\U0001F494 #Disappointed #WhyThis",
        # smiling face with heart-eyes, sparkles
        "Omg that\u2019s incredible news! \U0001F60D\u2728 #Amazing #Grateful",
        # exploding head, thinking face
        "Head full of thoughts rn \U0001F92F\U0001F914 #Confused #Overthinking",
        # pouting face, fire
        "Still can\u2019t believe this happened. So pissed! \U0001F621\U0001F525 #Angry #Unbelievable",
    ]
    predict_examples(model, tokenizer, test_examples, label_list, glove_embeddings, emoji2vec, embedding_dim, max_len, logger)

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "SVM-results.txt")
    with open(output_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== SVM Results - {args.model_type} - {pd.Timestamp.now()} ===\n")
        f.write(f"Raw Accuracy: {raw_acc:.4f}\n")
        f.write(f"Micro F1: {micro_f1:.4f}\n")
        f.write(f"Macro F1: {macro_f1:.4f}\n")
        f.write(f"Weighted F1: {weighted_f1:.4f}\n")
        f.write(f"\nClassification Report:\n{report}\n")

    local_model_dir = os.path.join(args.ckpt_dir, f"hf_{args.model_type}")
    config = {
        "vocab_size": vocab_size,
        "max_len": max_len,
        "embedding_dim": embedding_dim,
        "num_labels": num_labels,
        "model_type": args.model_type,
        "threshold_percentile": threshold_percentile,
        "seed": args.seed,
    }
    _save_artifacts(model, tokenizer, label_list, config, local_model_dir, args.model_type)
    _upload_to_hub(local_model_dir, args)


if __name__ == "__main__":
    main()