## **Tóm tắt những phần thay đổi**

1. **Phần import thư viện thêm:**

   ```python
   import nltk
   nltk.download('wordnet', quiet=True)
   nltk.download('omw-1.4', quiet=True)
   ```

2. Những phần có ghi **"bản mới"** sẽ *thay thế cho các đoạn code đánh dấu "bản cũ"*, và **không dùng đến các đoạn *bản cũ* nữa**.

In [1]:
!pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0.post1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.0 kB)
Downloading faiss_cpu-1.11.0.post1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (31.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.3/31.3 MB[0m [31m67.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.11.0.post1


In [2]:
!pip install flaml kagglehub pandas scikit-learn

Collecting flaml
  Downloading FLAML-2.3.5-py3-none-any.whl.metadata (16 kB)
Downloading FLAML-2.3.5-py3-none-any.whl (322 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m322.2/322.2 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: flaml
Successfully installed flaml-2.3.5


In [3]:
import pandas as pd
import numpy as np
import random
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import faiss
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import json
from datetime import datetime
from collections import Counter
import re
import warnings
import gdown
import kagglehub
from kagglehub import KaggleDatasetAdapter
warnings.filterwarnings('ignore')

import nltk
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

True

## Data Augmentation (bản mới)

In [5]:
# ==============================
# Class sinh dữ liệu tinh vi để augmentation
# ==============================
import pandas as pd
import random
from nltk.corpus import wordnet
import requests

class HardExampleGenerator:
    def __init__(self, dataset_path, alpha_spam=0.5, alpha_ham=0.3, use_llm_phrases=False):
        """
        Args:
            dataset_path (str): đường dẫn file CSV chứa cột 'Message' và 'Category'
            alpha_spam (float): tỷ lệ nhân bản spam khi augment
            alpha_ham (float): tỷ lệ nhân bản ham khi augment
            use_llm_phrases (bool): nếu True thì chờ load LLM phrases sau bằng load_llm_phrases()
        """
        self.dataset_path = dataset_path
        self.alpha_spam = alpha_spam
        self.alpha_ham = alpha_ham
        self.df = pd.read_csv(dataset_path)

        # Nếu chưa có LLM phrases thì dùng cụm mặc định
        if not use_llm_phrases:
            self.spam_groups = self._init_spam_phrases()
            self.ham_groups = self._init_ham_phrases()
        else:
            # Khởi tạo rỗng, sau sẽ gán bằng load_llm_phrases()
            self.spam_groups = []
            self.ham_groups = []

    # Dùng cho cách 2 (có thể lấy từ LLM bên ngoài xịn hơn)
    def _init_spam_phrases(self):
        # Các cụm spam tinh vi (giống file gốc)
        # ----- 7 nhóm dấu hiệu spam -----
        financial_phrases = [
            "you get $100 back", "they refund $200 instantly",
            "limited $50 bonus for early registration", "earn $150/day remote work",
            "approved for a $500 credit", "quick $300 refund if you confirm",
            "they give $250 cashback if you check in early",
            "your account gets $100 instantly after confirmation",
            "instant $400 transfer if you reply YES today",
            "exclusive $600 grant approved for you"
        ]

        promotion_phrases = [
            "limited time offer ends tonight", "buy one get one free today only",
            "exclusive deal just for you", "hot sale up to 80% off",
            "flash sale starting in 2 hours", "new collection, free shipping worldwide",
            "best price guaranteed for early birds", "special discount coupon for first 100 buyers",
            "reserve now and get extra 20% off", "only 3 items left, order now!"
        ]

        lottery_phrases = [
            "congratulations! you’ve won a $1000 gift card", "you are selected to receive a free iPhone",
            "claim your $500 Amazon voucher now", "winner! reply to confirm your prize",
            "spin the wheel to win exciting gifts", "lucky draw winner – act fast",
            "redeem your exclusive prize today", "final reminder: unclaimed reward waiting",
            "instant gift unlocked, tap to get", "biggest jackpot giveaway this week"
        ]

        scam_alert_phrases = [
            "your account will be suspended unless verified", "unusual login detected, reset password now",
            "security update required immediately", "urgent: payment failed, update details now",
            "verify your identity to avoid account closure", "your Netflix subscription is on hold, confirm payment",
            "important: unauthorized activity detected", "bank alert: confirm transaction or account locked",
            "last warning: confirm within 24 hours", "emergency: suspicious access blocked, verify"
        ]

        call_to_action_phrases = [
            "click here to confirm", "reply YES to activate bonus",
            "register before midnight and win", "tap now to claim your reward",
            "sign up today, limited seats", "confirm immediately to proceed",
            "act fast, offer expires soon", "verify email to continue",
            "download the app and get free points", "complete payment within 12 hours"
        ]

        social_engineering_phrases = [
            "hey grandma, i need $500 for hospital bills", "hi mom, send money asap, phone broke",
            "boss asked me to buy 3 gift cards urgently", "john, can you transfer $300 now, emergency",
            "it’s me, your cousin, stuck abroad, need help", "friend, please help me with $200 loan",
            "hi, i lost my wallet, send $150 to this account", "urgent! i can’t talk now, send cash fast",
            "help me pay this fine, will return tomorrow", "sister, please pay $400 for my surgery"
        ]

        obfuscated_phrases = [
            "Cl!ck h3re t0 w1n fr€e iPh0ne", "G€t y0ur r3fund n0w!!!",
            "L!mited 0ff3r: Fr33 $$$ r3ward", "C@shb@ck av@il@ble t0d@y",
            "W!n b!g pr!ze, act f@st", "Cl@im y0ur 100% b0nus",
            "Fr33 g!ft w!th 0rder", "Up t0 $5000 r3fund @pprov3d",
            "R3ply N0W t0 r3c3ive $$$", "Urg3nt!!! C0nfirm d3tails 1mm3di@tely"
        ]

        # Gom các nhóm vào 1 danh sách
        spam_phrase_groups = [
            financial_phrases, promotion_phrases, lottery_phrases,
            scam_alert_phrases, call_to_action_phrases,
            social_engineering_phrases, obfuscated_phrases
        ]
        return spam_phrase_groups

    # Dùng cho cách 2 (có thể lấy từ LLM bên ngoài xịn hơn)
    def _init_ham_phrases(self):
        # Các cụm ham dễ gây hiểu nhầm
        # ----- 7 nhóm cụm dễ gây hiểu nhầm thành spam (giống spam phrases) -----
        financial_phrases = [
            "I got $100 cashback yesterday", "The bank refunded me $200 already",
            "I earned $150/day last month from freelancing", "Approved for $500 loan finally",
            "Got quick $300 refund after confirmation", "The store gave me $250 cashback",
            "My account got $100 instantly after confirming", "I received instant $400 transfer today",
            "They sent me exclusive $600 grant, lol", "Netflix actually gave me 3 months free"
        ]

        promotion_phrases = [
            "I bought one and got one free, legit deal", "Flash sale 80% off, I already ordered",
            "Exclusive deal worked for me, saved a lot", "Hot sale 2 hours ago, crazy cheap",
            "New collection free shipping, I tried it", "Best price ever for members",
            "Got special coupon, it worked!", "Reserved early and saved 20%",
            "Only 3 items left when I bought mine", "Order now, it’s real not fake"
        ]

        lottery_phrases = [
            "I actually won a $1000 voucher at the mall", "I got a free iPhone from the lucky draw",
            "Claimed my $500 Amazon voucher legit", "Won a prize, just showed my ticket",
            "Spun the wheel at the fair and got gifts", "Lucky draw worked for me today",
            "Redeemed my exclusive prize at the shop", "They reminded me to collect my reward",
            "Gift unlocked at the event, so fun", "Jackpot giveaway, real not scam"
        ]

        scam_alert_phrases = [
            "I got unusual login alert, but it was me", "Reset my password after warning, fine now",
            "Got security update mail, confirmed it’s real", "Payment failed once, updated and ok now",
            "Had to verify identity, bank confirmed legit", "Netflix on hold but paid, no issue",
            "Bank asked to confirm transaction, was me", "Warning mail yesterday, false alarm",
            "Confirmed within 24h, all safe", "Suspicious access blocked, just me traveling"
        ]

        call_to_action_phrases = [
            "I clicked to confirm and it worked", "Replied YES, bonus legit",
            "Registered before midnight, no scam", "Tapped link, claimed reward legit",
            "Signed up today, limited seat real", "Confirmed immediately, nothing shady",
            "Acted fast, got discount legit", "Verified email, safe and done",
            "Downloaded app, free points real", "Paid within 12 hours, successful"
        ]

        social_engineering_phrases = [
            "Mom, don’t worry I sent you $500 hospital bill already", "Hi mom, phone broke but friend helped",
            "Boss asked me to buy gift cards for office, already did", "John, I transferred $300, check it",
            "Cousin stuck abroad, we sent help", "Friend lent me $200 last week, repaid",
            "Lost wallet but someone returned $150", "Urgent cash request yesterday, sorted now",
            "Helped pay fine, friend returned", "Sister’s surgery done, paid $400 legit"
        ]

        obfuscated_phrases = [
            "Clicked h3re to win fr€e gift, real promo", "Got r3fund n0w!!! 100% legit",
            "Fr33 reward worked, tried it", "C@shb@ck real, used today",
            "Won prize real, not spam", "Cl@imed b0nus myself, safe",
            "Gift order legit, no scam", "Refund approved @ bank, no issue",
            "Replied N0W got $$$ legit", "Urg3nt confirm done, real bank"
        ]

        hard_ham_phrase_groups = [
            financial_phrases, promotion_phrases, lottery_phrases,
            scam_alert_phrases, call_to_action_phrases,
            social_engineering_phrases, obfuscated_phrases
        ]
        return hard_ham_phrase_groups

    # Dùng cho cách 1
    def generate_like_spam_ham(self, label='spam', n_per_group=10, api_key=None, model="mistralai/Mixtral-8x7B-Instruct-v0.1", group=None):
        """
        Sinh các câu spam/ham tinh vi mô phỏng người dùng, chia theo 7 nhóm phổ biến (70 câu total).

        Args:
            label (str): 'spam' hoặc 'ham'
            n_per_group (int): số câu trên mỗi nhóm
            api_key (str): Together.ai API key
            model (str): Model ID (Mixtral, LLaMA3,...)
            group (str or None): nếu chỉ muốn sinh 1 nhóm, chọn từ:
                'financial', 'promotion', 'lottery', 'scam_alert',
                'call_to_action', 'social_engineering', 'obfuscated'

        Returns:
            List[str]: Danh sách câu được sinh
        """
        if api_key is None:
            raise ValueError("❌ Cần cung cấp API key Together.ai.")

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

        group_prompts = {
            "financial": {
                "spam": "Generate realistic user-style spam messages that pretend to offer cashback, refunds, or financial rewards, but are actually deceptive.",
                "ham": "Generate legitimate human messages that mention refunds, cashback, or money transfers in real-life, harmless contexts."
            },
            "promotion": {
                "spam": "Generate spammy messages that appear friendly but are disguised promotions, sales, or limited-time offers.",
                "ham": "Generate genuine user messages that talk about real promotions or sales they used, sounding casual and truthful."
            },
            "lottery": {
                "spam": "Generate scam-like messages that claim the user won a lottery, prize, or giveaway — but in a deceptive, subtle tone.",
                "ham": "Generate honest messages where users talk about actually winning something in real life — malls, fairs, etc."
            },
            "scam_alert": {
                "spam": "Generate deceptive user-style spam about account alerts, security warnings, or password issues to trick the recipient.",
                "ham": "Generate real user messages where people talk about security alerts or login issues they experienced, in normal tone."
            },
            "call_to_action": {
                "spam": "Write spam messages with subtle calls to action like 'click here', 'register', or 'confirm' hidden in casual tone.",
                "ham": "Write normal human messages that mention clicking links or confirming actions, but are not spam."
            },
            "social_engineering": {
                "spam": "Generate spam messages that use fake urgency or personal relationships (e.g., 'Mom', 'Boss', 'Friend') to request money.",
                "ham": "Generate real messages from people who had real emergencies or money transfers, in personal tone."
            },
            "obfuscated": {
                "spam": "Write spam messages that use obfuscated text like '$$$', 'Fr33', 'Cl!ck', to bypass filters but sound human.",
                "ham": "Write real human messages that coincidentally use symbols or strange formats, but are not spam."
            }
        }

        selected_groups = [group] if group else list(group_prompts.keys())
        all_outputs = []

        for g in selected_groups:
            system_prompt = group_prompts[g][label]
            full_prompt = f"{system_prompt}\nGenerate {n_per_group} examples. Output only the messages, one per line."

            payload = {
                "model": model,
                "prompt": full_prompt,
                "max_tokens": 1000,
                "temperature": 0.9,
                "top_p": 0.95
            }

            print(f"📡 Generating {label.upper()} – Group: {g} ...")

            response = requests.post("https://api.together.xyz/v1/completions", headers=headers, json=payload)

            if response.ok:
                raw_output = response.json()["choices"][0]["text"].strip()
                lines = [line.strip("-•* ") for line in raw_output.splitlines() if line.strip()]
                all_outputs.extend(lines)
            else:
                raise RuntimeError(f"❌ API error @group {g}: {response.status_code} - {response.text}")

        return all_outputs

    # Dùng cho cách 1
    def load_llm_phrases(self, spam_list, ham_list, group_size=10):
        """
        Từ danh sách 70 câu spam + 70 câu ham, chia thành 7 nhóm (mỗi nhóm 10 câu).
        Dùng thay cho _init_spam_phrases() và _init_ham_phrases()

        Args:
            spam_list (list[str]): Danh sách 70 câu spam từ LLM
            ham_list (list[str]): Danh sách 70 câu ham từ LLM
            group_size (int): Số câu mỗi nhóm (mặc định 10)

        Tác dụng:
            Gán trực tiếp vào self.spam_groups và self.ham_groups
        """
        #assert len(spam_list) == len(ham_list) == 70, "❌ Cần đúng 70 câu mỗi loại để chia nhóm."
        spam_list = spam_list[:70]
        ham_list = ham_list[:70]
        self.spam_groups = [spam_list[i:i+group_size] for i in range(0, 70, group_size)]
        self.ham_groups = [ham_list[i:i+group_size] for i in range(0, 70, group_size)]
        print("✅ Đã load 140 câu LLM và chia thành 7 nhóm spam/ham.")
        return self.spam_groups, self.ham_groups

    def _generate_sentences(self, base_texts, phrase_groups, n):
        results = []
        for _ in range(n):
            base = random.choice(base_texts)
            insert = random.choice(random.choice(phrase_groups))
            sentence = f"{insert}. {base}" if random.random() < 0.5 else f"{base}, btw {insert}."
            results.append(sentence)
        return results

    def generate_hard_spam(self, output_path="/content/hard_spam_generated_auto.csv"):
        num_ham = self.df[self.df["Category"] == "ham"].shape[0]
        num_spam = self.df[self.df["Category"] == "spam"].shape[0]
        if num_spam >= num_ham:
            print("✅ Spam đã đủ, không sinh thêm.")
            return []
        n_generate = int((num_ham - num_spam) * self.alpha_spam)
        base_texts = self.df[self.df["Category"] == "ham"]["Message"].sample(n=n_generate, random_state=42).tolist()
        generated = self._generate_sentences(base_texts, self.spam_groups, n_generate)
        pd.DataFrame({"Category": ["spam"] * n_generate, "Message": generated}).to_csv(output_path, index=False)
        print(f"✅ Sinh {n_generate} hard spam -> {output_path}")
        return generated

    def generate_hard_ham(self, output_path="/content/hard_ham_generated_auto.csv"):
        num_ham = self.df[self.df["Category"] == "ham"].shape[0]
        num_spam = self.df[self.df["Category"] == "spam"].shape[0]
        if num_ham >= num_spam:
            n_generate = int((num_ham - num_spam) * self.alpha_ham)
            base_texts = self.df[self.df["Category"] == "ham"]["Message"].sample(n=n_generate, random_state=42).tolist()
            generated = self._generate_sentences(base_texts, self.ham_groups, n_generate)
            pd.DataFrame({"Category": ["ham"] * n_generate, "Message": generated}).to_csv(output_path, index=False)
            print(f"✅ Sinh {n_generate} hard ham -> {output_path}")
            return generated
        else:
            print("✅ Ham đã đủ, không cần sinh thêm.")
            return []

    def generate_synonym_replacement(self, messages, labels, aug_ratio=0.2):
        MAX_AUG = int(len(messages) * aug_ratio)
        augmented_messages, augmented_labels = [], []
        print(f"✅ Synonym Replacement: sinh tối đa {MAX_AUG} câu.")
        for msg, label in zip(messages, labels):
            if len(augmented_messages) >= MAX_AUG:
                break
            if random.random() > 0.8:
                aug_msg = self.synonym_replacement(msg)
                if aug_msg != msg:
                    augmented_messages.append(aug_msg)
                    augmented_labels.append(label)
        print(f"✅ Đã sinh {len(augmented_messages)} câu augmented thực tế.")
        return augmented_messages, augmented_labels

    def synonym_replacement(self, text, n=1):
        words = text.split()
        new_words = words.copy()
        candidates = [w for w in words if wordnet.synsets(w)]
        if not candidates:
            return text
        random.shuffle(candidates)
        replaced_count = 0
        for random_word in candidates:
            synonyms = wordnet.synsets(random_word)
            if synonyms:
                synonym = synonyms[0].lemmas()[0].name().replace('_', ' ')
                if synonym.lower() != random_word.lower():
                    new_words = [synonym if w == random_word else w for w in new_words]
                    replaced_count += 1
            if replaced_count >= n:
                break
        return " ".join(new_words)

    # Tự động sinh test case
    def generate_user_like_spam_ham(self, label='spam', n=10, api_key=None, model="mistralai/Mixtral-8x7B-Instruct-v0.1"):
        """
        Sinh ra các câu spam/ham giống như tin nhắn từ người dùng thật có nội dung hỏi hoặc trò chuyện.

        Args:
            label (str): 'spam' hoặc 'ham'.
            n (int): Số lượng cần sinh.
            api_key (str): Together.ai API key.
            model (str): Model ID (Mixtral/Mistral/LLaMA3...)

        Returns:
            List[str]: Danh sách tin nhắn được sinh ra.
        """
        import requests

        if api_key is None:
            raise ValueError("❌ Bạn cần cung cấp Together.ai API key.")

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

        prompt_template = {
            "spam": (
                "You are writing deceptive user messages that look like innocent questions, but are actually subtle spam.\n"
                "Generate realistic user messages (in casual style) that include spam signals, but sound like real human questions or messages.\n"
                f"Generate {n} such examples. Output only the messages, one per line."
            ),
            "ham": (
                "You are writing user messages that look like spam at first, but are actually legitimate, honest messages.\n"
                "Generate realistic messages where a user might mention cashback, refund, login alerts, etc. but in a real, harmless context.\n"
                f"Generate {n} such examples. Output only the messages, one per line."
            )
        }

        prompt = prompt_template[label]

        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": 1000,
            "temperature": 0.9,
            "top_p": 0.95,
            "stop": None
        }

        response = requests.post("https://api.together.xyz/v1/completions", headers=headers, json=payload)

        if response.ok:
            raw_output = response.json()["choices"][0]["text"].strip()
            # Tách các dòng nếu có xuống dòng
            return [line.strip("-•* ") for line in raw_output.splitlines() if line.strip()]
        else:
            raise RuntimeError(f"Lỗi khi gọi Together API: {response.status_code} - {response.text}")

## Import kaggle dataset

In [6]:
def load_data_from_kaggle():
    """Fetch the Vietnamese spam dataset from Kaggle and return it as a DataFrame.

    Prints the record count and the first five rows as a quick sanity check.
    """
    dataset_handle = "victorhoward2/vietnamese-spam-post-in-social-network"
    csv_name = "vi_dataset.csv"

    frame = kagglehub.load_dataset(KaggleDatasetAdapter.PANDAS, dataset_handle, csv_name)

    print(f"Successfully loaded Kaggle dataset with {len(frame)} records")
    print("First 5 records:")
    print(frame.head())
    return frame

## Import google drive (bản mới)

In [8]:
def load_data_from_gdrive(file_id="1N7rk-kfnDFIGMeX0ROVTjKh71gcgx-7R", output_path='/content/2cls_spam_text_cls.csv'):
    """Download the base spam dataset from Google Drive and load it as a DataFrame.

    Uses the gdown Python API rather than the `!gdown --id ...` shell magic,
    so the function is valid Python outside IPython and the file is written
    to exactly the path that is read back afterwards.

    Args:
        file_id (str): Google Drive file ID of the CSV.
        output_path (str): where the downloaded CSV is stored. The default is
            the Colab working directory used by the notebook.

    Returns:
        pandas.DataFrame: the dataset as read from the downloaded CSV.

    Raises:
        RuntimeError: if gdown reports that nothing was downloaded.
    """
    # 1. Download the file from Google Drive
    downloaded_path = gdown.download(id=file_id, output=output_path, quiet=False)
    if downloaded_path is None:
        raise RuntimeError(f"Failed to download Google Drive file {file_id}")

    # 2. Load the original dataset
    df_base = pd.read_csv(downloaded_path)
    print("First 5 records:")
    print(df_base.head())

    return df_base

## Prepare Data

In [9]:
def preprocess_dataframe(df):
    """Extract cleaned messages and normalised ham/spam labels from a DataFrame.

    The text column is the first column whose lower-cased name is a known
    text name or contains 'text'/'message'; the label column is the first
    whose name is a known label name or contains 'label'. When no column
    matches, the first column is used for text and the second (or the first,
    if there is only one) for labels.

    Rows with a missing or whitespace-only text are dropped. Labels are
    stripped, lower-cased and mapped to 'ham'/'spam' where a mapping is
    known; unknown labels are kept as their lower-cased string.

    The input frame is not modified.

    Args:
        df (pandas.DataFrame): raw dataset.

    Returns:
        tuple[list[str], list[str]]: (messages, labels), aligned by position.
    """
    print("Preprocessing dataframe...")
    print(f"Columns available: {list(df.columns)}")

    # Try to identify text and label columns
    text_candidates = ['message', 'text', 'content', 'email', 'post', 'comment', "texts_vi"]
    label_candidates = ['label', 'class', 'category', 'type']

    def first_match(candidates, fragments):
        """Return the first column matching a candidate name or containing a fragment."""
        for col in df.columns:
            name = str(col).lower()  # str() tolerates non-string column names
            if name in candidates or any(fragment in name for fragment in fragments):
                return col
        return None

    text_column = first_match(text_candidates, ('text', 'message'))
    label_column = first_match(label_candidates, ('label',))

    # If not found, use first two columns
    if text_column is None:
        text_column = df.columns[0]
        print(f"Text column not found, using first column: {text_column}")

    if label_column is None:
        label_column = df.columns[1] if len(df.columns) > 1 else df.columns[0]
        print(f"Label column not found, using: {label_column}")

    print(f"Using text column: {text_column}")
    print(f"Using label column: {label_column}")

    # Work on a copy so the caller's DataFrame is left untouched.
    df = df.copy()

    # Fill missing values BEFORE casting: astype(str) would turn them into the
    # literal strings 'nan'/'None', which then survive the empty-text filter.
    df[text_column] = df[text_column].fillna('').astype(str)
    df = df[df[text_column].str.strip() != ''].copy()  # Remove empty texts

    # Clean labels - convert to ham/spam format
    df[label_column] = df[label_column].astype(str).str.strip().str.lower()

    # Map various label formats to ham/spam
    label_mapping = {
        '0': 'ham', '1': 'spam',
        'ham': 'ham', 'spam': 'spam',
        'normal': 'ham',
        'legitimate': 'ham', 'phishing': 'spam',
        'not_spam': 'ham', 'is_spam': 'spam'
    }

    # Unknown labels fall back to their cleaned original value.
    df[label_column] = df[label_column].map(label_mapping).fillna(df[label_column])

    # Show label distribution
    label_counts = df[label_column].value_counts()
    print("Label distribution:")
    for label, count in label_counts.items():
        print(f"  {label}: {count} samples")

    messages = df[text_column].tolist()
    labels = df[label_column].tolist()

    print(f"Processed {len(messages)} messages")
    return messages, labels

## Load data

In [10]:
def load_dataset(source='kaggle', file_id=None):
    """Load a dataset from the chosen source and return (messages, labels).

    Args:
        source (str): 'kaggle' or 'gdrive'.
        file_id: accepted but currently not used by this function.

    Returns:
        tuple: the (messages, labels) pair produced by preprocess_dataframe().

    Raises:
        ValueError: if source is neither 'kaggle' nor 'gdrive'.
        Exception: if the selected loader returned None.
    """
    if source not in ('kaggle', 'gdrive'):
        raise ValueError("Source must be 'kaggle' or 'gdrive'")

    frame = load_data_from_kaggle() if source == 'kaggle' else load_data_from_gdrive()

    if frame is None:
        raise Exception(f"Failed to load data from {source}")

    texts, targets = preprocess_dataframe(frame)
    return texts, targets

## Embedding model

In [11]:
# Checkpoint used to embed messages (multilingual E5, base size).
model_name = "intfloat/multilingual-e5-base"

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)

# The model is only used for inference: put it in evaluation mode.
model.eval()

tokenizer_config.json:   0%|          | 0.00/418 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

XLMRobertaModel(
  (embeddings): XLMRobertaEmbeddings(
    (word_embeddings): Embedding(250002, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): XLMRobertaEncoder(
    (layer): ModuleList(
      (0-11): 12 x XLMRobertaLayer(
        (attention): XLMRobertaAttention(
          (self): XLMRobertaSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): XLMRobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine

In [12]:
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool token states over the sequence axis, ignoring masked positions.

    Positions where attention_mask is 0 are zeroed before summing over dim 1,
    and the sum is divided by the number of unmasked positions per row.
    """
    keep = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / token_counts

In [13]:
def get_embeddings(texts, model, tokenizer, device, batch_size=32):
    """Embed texts in batches and return one L2-normalised vector per text.

    Every text is prefixed with "passage: " before tokenisation, inputs are
    truncated to 512 tokens, and token states are mean-pooled with
    average_pool().

    Args:
        texts: sequence of texts to embed (sliced in steps of batch_size).
        model: encoder called with the tokenizer output; its result must
            expose last_hidden_state.
        tokenizer: tokenizer matching the model.
        device: device the input tensors are moved to.
        batch_size (int): number of texts per forward pass.

    Returns:
        numpy.ndarray: the per-batch embeddings stacked vertically.
    """
    chunks = []
    batch_starts = range(0, len(texts), batch_size)

    for start in tqdm(batch_starts, desc="Generating embeddings"):
        prefixed = [f"passage: {text}" for text in texts[start:start + batch_size]]
        encoded = tokenizer(prefixed, max_length=512, padding=True, truncation=True, return_tensors="pt")
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        # Inference only: no gradients needed.
        with torch.no_grad():
            hidden = model(**encoded).last_hidden_state
            pooled = average_pool(hidden, encoded["attention_mask"])
            unit_vectors = F.normalize(pooled, p=2, dim=1)
            chunks.append(unit_vectors.cpu().numpy())

    return np.vstack(chunks)

## Handle imbalanced label data

In [14]:
def calculate_class_weights(labels):
    """Compute inverse-frequency class weights for imbalanced labels.

    Each class gets weight total_samples / (num_classes * class_count), so
    rarer classes receive larger weights. The distribution and weights are
    printed as a side effect.

    Args:
        labels: sequence of hashable class labels.

    Returns:
        dict: mapping from label to its weight.
    """
    frequencies = Counter(labels)
    n_samples = len(labels)
    n_classes = len(frequencies)

    weights = {
        cls: n_samples / (n_classes * freq)
        for cls, freq in frequencies.items()
    }

    print("Class distribution:")
    for cls, freq in frequencies.items():
        print(f"  {cls}: {freq} samples (weight: {weights[cls]:.3f})")

    return weights

## Compute saliency scores

In [15]:
def compute_saliency_scores(query_text, model, tokenizer, device, index, train_metadata, k=10):
    """Compute per-token saliency scores for explainability.

    The "spam score" of a text is the sum of the similarity scores of its k
    nearest neighbours in index whose metadata label is "spam". Each token of
    the query is replaced in turn by the tokenizer's pad token, and its
    saliency is the drop in spam score caused by that replacement. Saliencies
    are min-max normalised to [0, 1].

    Args:
        query_text (str): text to explain.
        model: encoder whose output exposes last_hidden_state.
        tokenizer: tokenizer matching the model.
        device: device the input tensors are moved to.
        index: vector index exposing search(embeddings, k) -> (scores, indices).
        train_metadata: sequence indexed by neighbour index; each item has a
            "label" key.
        k (int): number of neighbours retrieved per query.

    Returns:
        numpy.ndarray: one normalised saliency per token of query_text, or
        array([1.0]) when the text has at most one token.
    """
    tokens = tokenizer.tokenize(query_text)

    if len(tokens) <= 1:
        return np.array([1.0])

    def spam_score(text):
        """Embed text as a query and sum the scores of its spam neighbours."""
        encoded = tokenizer([f"query: {text}"], max_length=512, padding=True, truncation=True, return_tensors="pt")
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            outputs = model(**encoded)
            embedding = average_pool(outputs.last_hidden_state, encoded["attention_mask"])
            embedding = F.normalize(embedding, p=2, dim=1)
            embedding = embedding.cpu().numpy().astype("float32")

        scores, indices = index.search(embedding, k)
        # FAISS pads with index -1 when fewer than k vectors are available;
        # without the guard that would silently read train_metadata[-1].
        return sum(
            score
            for score, idx in zip(scores[0], indices[0])
            if idx >= 0 and train_metadata[idx]["label"] == "spam"
        )

    original_spam_score = spam_score(query_text)

    # Saliency of a token = how much the spam score falls when it is masked.
    saliencies = []
    for position in range(len(tokens)):
        masked_tokens = tokens.copy()
        masked_tokens[position] = tokenizer.pad_token
        masked_text = tokenizer.convert_tokens_to_string(masked_tokens)
        saliencies.append(original_spam_score - spam_score(masked_text))

    # Min-max normalise; the epsilon avoids division by zero when all
    # saliencies are equal. At least two tokens are guaranteed here.
    arr = np.array(saliencies)
    return (arr - arr.min()) / (np.ptp(arr) + 1e-12)


## Classification with weighted KNN

In [16]:
def classify_with_weighted_knn(query_text, model, tokenizer, device, index, train_metadata, class_weights, k=10, alpha=0.5, explain=False):
    """Classify a message as "ham" or "spam" by a weighted vote of its nearest neighbours.

    The message is embedded with a ``"query: "`` prefix, the ``k`` closest
    training vectors are fetched from ``index``, and each neighbour adds

        w = (1 - alpha) * similarity * class_weights[label] + alpha * saliency_weight

    to the vote of its own label. The saliency term is the same for every
    neighbour of a given query, whatever the neighbour's label.

    Args:
        query_text: Raw message text to classify.
        model: Encoder called as ``model(**batch)``; its ``last_hidden_state`` is pooled.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized batch is moved to.
        index: Search index exposing ``search(embeddings, k)``.
        train_metadata: Sequence of dicts with ``"label"`` and ``"message"`` keys,
            positionally aligned with the vectors stored in ``index``.
        class_weights: Mapping from label to its multiplicative weight.
        k: Number of neighbours requested from the index.
        alpha: Mixing factor between the similarity term and the saliency term.
        explain: When True, use ``compute_saliency_scores`` (per-token) and return
            the tokens and their scores; otherwise use ``compute_quick_saliency``.

    Returns:
        dict with keys ``"prediction"``, ``"vote_scores"``, ``"neighbors"``,
        ``"saliency_weight"`` and ``"alpha"``; plus ``"tokens"`` and
        ``"saliency_scores"`` when ``explain`` is True. On a tie (including the
        case where no neighbour was found) the prediction is ``"ham"``.
    """
    # Embed the query.
    query_with_prefix = f"query: {query_text}"
    batch_dict = tokenizer([query_with_prefix], max_length=512, padding=True, truncation=True, return_tensors="pt")
    batch_dict = {name: tensor.to(device) for name, tensor in batch_dict.items()}

    with torch.no_grad():
        outputs = model(**batch_dict)
        query_embedding = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
        query_embedding = F.normalize(query_embedding, p=2, dim=1)
        query_embedding = query_embedding.cpu().numpy().astype("float32")

    # Nearest neighbours of the query.
    scores, indices = index.search(query_embedding, k)

    # Per-query saliency weight.
    if explain:
        saliency_scores = compute_saliency_scores(query_text, model, tokenizer, device, index, train_metadata, k)
        saliency_weight = np.mean(saliency_scores)
        tokens = tokenizer.tokenize(query_text)
    else:
        # Cheap keyword-based approximation.
        saliency_weight = compute_quick_saliency(query_text)
        saliency_scores = None
        tokens = None

    # Accumulate the weighted votes.
    vote_scores = {"ham": 0.0, "spam": 0.0}
    neighbor_info = []

    for raw_score, raw_idx in zip(scores[0], indices[0]):
        neighbor_idx = int(raw_idx)
        if neighbor_idx < 0:
            # FAISS pads the result with -1 when the index holds fewer than k
            # vectors; indexing train_metadata[-1] would silently let the last
            # training sample vote, so padded slots are skipped.
            continue

        similarity = float(raw_score)
        neighbor = train_metadata[neighbor_idx]
        neighbor_label = neighbor["label"]
        neighbor_message = neighbor["message"]

        # w = (1-α)×similarity×class_weight + α×saliency_weight
        weight = (1 - alpha) * similarity * class_weights[neighbor_label] + alpha * saliency_weight

        vote_scores[neighbor_label] += weight

        neighbor_info.append({
            "score": similarity,
            "weight": weight,
            # Long messages are truncated to 100 characters for display.
            "message": neighbor_message[:100] + "..." if len(neighbor_message) > 100 else neighbor_message,
            "label": neighbor_label,
        })

    # The label with the larger accumulated vote wins.
    predicted_label = max(vote_scores, key=vote_scores.get)

    result = {
        "prediction": predicted_label,
        "vote_scores": vote_scores,
        "neighbors": neighbor_info,
        "saliency_weight": saliency_weight,
        "alpha": alpha
    }

    if explain:
        result["tokens"] = tokens
        result["saliency_scores"] = saliency_scores

    return result

In [17]:
def compute_quick_saliency(text):
    """Return a cheap keyword-based spam saliency score in the range [0.1, 1.0].

    The text is lower-cased and scanned for five groups of signals: basic spam
    keywords, social-engineering phrases, urgency phrases, money amounts
    (regex) and suspicious conversational contexts. The weighted hit count is
    scaled by a length factor, divided by the number of words, offset by 0.2
    and clamped to [0.1, 1.0].

    Args:
        text: Message text (English or Vietnamese).

    Returns:
        float saliency score; 0.2 for a text with no hits (including "").
    """
    text_lower = text.lower()
    words = text_lower.split()

    basic_spam_keywords = [
        'free', 'click', 'urgent', 'limited', 'offer', 'discount', 'sale', 'win', 'prize',
        'money', 'cash', 'earn', 'guaranteed', 'act now', 'call now', 'congratulations',
        'miễn phí', 'khuyến mãi', 'giảm giá', 'ưu đãi', 'thắng', 'giải thưởng', 'tiền',
        'kiếm tiền', 'đảm bảo', 'hành động ngay', 'chúc mừng', 'cơ hội', 'quà tặng'
    ]

    social_engineering_keywords = [
        'mom', 'boss', 'hr', 'manager', 'security update', 'unusual login',
        'hospital bill', 'emergency', 'help buy', 'reimburse', 'gift cards',
        'short-staffed', 'extra shifts', 'card was declined', 'warranty',
        'mẹ', 'sếp', 'nhân sự', 'cập nhật bảo mật', 'đăng nhập bất thường',
        'viện phí', 'khẩn cấp', 'giúp mua', 'hoàn tiền'
    ]

    urgency_patterns = [
        'today', 'tomorrow', 'this week', 'before friday', 'reply yes',
        'cancel anytime', 'confirm before', 'register early', 'already got mine',
        'hôm nay', 'ngày mai', 'tuần này', 'trước thứ sáu', 'trả lời có'
    ]

    money_patterns = [
        r'\$\d+', r'\d+\$', r'\d+\s*dollar', r'\d+\s*usd',
        r'\d+\s*triệu', r'\d+\s*nghìn', r'\d+\s*đồng'
    ]

    suspicious_contexts = [
        'just signed up', 'they refund', 'i already got', 'you should try',
        'can you help', 'reply if', 'book a slot', 'free diagnostics',
        'vừa đăng ký', 'họ hoàn tiền', 'tôi đã nhận', 'bạn nên thử'
    ]

    # Basic keywords: single-word keywords are counted once per word that
    # contains them. Multi-word keywords ('act now', 'miễn phí', ...) can never
    # be a substring of a whitespace-split word, so they are matched against
    # the whole text instead, once per phrase.
    single_word_keywords = [kw for kw in basic_spam_keywords if " " not in kw]
    phrase_keywords = [kw for kw in basic_spam_keywords if " " in kw]
    word_hits = sum(1 for word in words if any(kw in word for kw in single_word_keywords))
    phrase_hits = sum(1 for phrase in phrase_keywords if phrase in text_lower)
    basic_spam_score = word_hits + phrase_hits

    # NOTE(review): the checks below are plain substring tests, so very short
    # entries such as 'hr' or 'mom' also fire inside longer words
    # ('three', 'moment'). Left unchanged here; consider word-boundary matching.
    social_eng_score = sum(2 for keyword in social_engineering_keywords if keyword in text_lower)  # Higher weight
    urgency_score = sum(1.5 for pattern in urgency_patterns if pattern in text_lower)

    # Each distinct money pattern found contributes a high weight.
    money_score = sum(2 for pattern in money_patterns if re.search(pattern, text_lower))

    suspicious_score = sum(1.5 for context in suspicious_contexts if context in text_lower)

    # Very short texts are damped, very long ones boosted.
    if len(words) < 5:
        length_factor = 0.8
    elif len(words) > 50:
        length_factor = 1.2
    else:
        length_factor = 1.0

    total_score = (basic_spam_score + social_eng_score + urgency_score + money_score + suspicious_score) * length_factor

    # Normalise by word count, add a 0.2 floor offset, clamp to [0.1, 1.0].
    return min(1.0, max(0.1, total_score / max(len(words), 1) + 0.2))

## Find the best alpha for the weighting formula

In [18]:
def optimize_alpha_parameter(test_embeddings, test_labels, test_metadata, index, train_metadata, class_weights, k=10):
    """Grid-search alpha over 0.0, 0.1, ..., 1.0 and return the most accurate value.

    Every message in ``test_metadata`` is classified with
    ``classify_with_weighted_knn`` for each candidate alpha, using the
    module-level ``model``, ``tokenizer`` and ``device``. Ties keep the smallest
    alpha, because only a strictly better accuracy replaces the current best.

    Args:
        test_embeddings: Only its length is used (number of samples to evaluate).
        test_labels: Unused; the true label is read from ``test_metadata``.
        test_metadata: Sequence of dicts with ``"label"`` and ``"message"`` keys.
        index: Search index passed through to the classifier.
        train_metadata: Metadata aligned with the vectors in ``index``.
        class_weights: Mapping from label to weight, passed to the classifier.
        k: Number of neighbours used by the classifier.

    Returns:
        ``(best_alpha, alpha_results)`` where ``alpha_results`` is a list of
        ``(alpha, accuracy)`` tuples. With no samples, ``(0.0, [])`` is returned
        instead of dividing by zero.
    """
    print("Optimizing alpha parameter...")

    total = len(test_embeddings)
    if total == 0:
        # Nothing to score: keep the default alpha rather than raising ZeroDivisionError.
        print("No evaluation samples provided; skipping alpha search.")
        return 0.0, []

    alpha_values = np.arange(0.0, 1.1, 0.1)
    best_alpha = 0.0
    best_accuracy = 0.0
    alpha_results = []

    for alpha in tqdm(alpha_values, desc="Testing alpha values"):
        correct = 0

        for i in range(total):
            sample = test_metadata[i]

            # NOTE: the classifier re-embeds the message on every call, so each
            # alpha costs one full pass of the encoder over the evaluation set.
            result = classify_with_weighted_knn(
                sample["message"], model, tokenizer, device, index, train_metadata,
                class_weights, k=k, alpha=alpha, explain=False
            )

            if result["prediction"] == sample["label"]:
                correct += 1

        accuracy = correct / total
        alpha_results.append((alpha, accuracy))

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_alpha = alpha

        print(f"Alpha: {alpha:.1f}, Accuracy: {accuracy:.4f}")

    print(f"\nBest alpha: {best_alpha:.1f} with accuracy: {best_accuracy:.4f}")
    return best_alpha, alpha_results

## Classify sub-category

In [19]:
def classify_spam_subcategory(spam_texts, model, tokenizer, device):
    """Assign each spam message to 'spam_quangcao', 'spam_hethong' or 'spam_khac'.

    For every message two scores per category are combined:
    cosine similarity between the message embedding and a category reference
    embedding (weight 0.7), and the fraction of the category's keywords found
    in the lower-cased message (weight 0.3). If the best combined score is
    below 0.3 the message is labelled 'spam_khac'.

    Args:
        spam_texts: List of message strings.
        model, tokenizer, device: Passed through to ``get_embeddings``.

    Returns:
        List of category names, one per input message, in input order.
        An empty input yields an empty list.
    """
    if not spam_texts:
        return []

    # 1. Embeddings of the messages themselves.
    spam_embeddings = get_embeddings(spam_texts, model, tokenizer, device)

    # 2. Reference descriptions for each category.
    reference_texts = {
        'spam_quangcao': {
            'vietnamese': "khuyến mãi giảm giá sale ưu đãi mua ngay giá rẻ miễn phí quà tặng voucher coupon giải thưởng trúng thưởng cơ hội trúng",
            'english': "promotional discount sale offer prize win money gift free deal bargain cheap special limited",
            'combined': "khuyến mãi giảm giá promotional discount sale ưu đãi offer prize win quà tặng gift free voucher coupon deal bargain trúng thưởng"
        },
        'spam_hethong': {
            'vietnamese': "thông báo cảnh báo tài khoản bảo mật xác nhận cập nhật hệ thống đăng nhập mật khẩu bị khóa hết hạn gia hạn",
            'english': "notification alert account security confirm update system login password locked expired renewal verify suspended warning",
            'combined': "thông báo notification cảnh báo alert tài khoản account bảo mật security xác nhận confirm cập nhật update hệ thống system đăng nhập login mật khẩu password"
        }
    }

    # Each category maps to a dict of language variants, so a single string has
    # to be picked before embedding; passing the dict itself would embed its
    # repr (or fail in the tokenizer). The bilingual 'combined' variant is used
    # because the messages may be in either language.
    reference_embeddings = {
        category: get_embeddings([variants['combined']], model, tokenizer, device)[0]
        for category, variants in reference_texts.items()
    }

    # Keyword table is loop-invariant, so it is built once.
    category_keywords = {
        'spam_quangcao': [
            # Vietnamese advertising keywords
            'khuyến mãi', 'giảm giá', 'sale', 'ưu đãi', 'mua ngay', 'giá rẻ', 'miễn phí',
            'quà tặng', 'voucher', 'coupon', 'giải thưởng', 'trúng thưởng', 'cơ hội', 'trúng',
            # English advertising keywords
            'discount', 'sale', 'offer', 'promotion', 'free', 'deal', 'buy now', 'limited time',
            'special offer', 'bargain', 'cheap', 'save money', 'win', 'prize', 'gift', 'won',
            'congratulations', 'claim', 'click here', '$', 'money', 'cash'
        ],
        'spam_hethong': [
            # Vietnamese system keywords
            'thông báo', 'cảnh báo', 'tài khoản', 'bảo mật', 'xác nhận', 'cập nhật',
            'hệ thống', 'đăng nhập', 'mật khẩu', 'bị khóa', 'hết hạn', 'gia hạn', 'khóa',
            # English system keywords
            'notification', 'alert', 'account', 'security', 'confirm', 'update',
            'system', 'login', 'password', 'locked', 'expired', 'renewal', 'verify',
            'suspended', 'warning', 'breach', 'urgent', 'immediately'
        ]
    }

    subcategories = []

    for text, text_embedding in zip(spam_texts, spam_embeddings):
        # A. Semantic similarity (cosine) against each category reference.
        bert_scores = {}
        for category, ref_emb in reference_embeddings.items():
            similarity = np.dot(text_embedding, ref_emb) / (
                np.linalg.norm(text_embedding) * np.linalg.norm(ref_emb)
            )
            bert_scores[category] = similarity

        # B. Keyword matching: fraction of the category's keywords present.
        text_lower = text.lower()
        keyword_scores = {
            category: sum(1 for keyword in keywords if keyword in text_lower) / len(keywords)
            for category, keywords in category_keywords.items()
        }

        # C. Combine: 70% embedding similarity, 30% keywords.
        final_scores = {
            category: 0.7 * bert_scores[category] + 0.3 * keyword_scores[category]
            for category in bert_scores
        }

        # D. Pick the best category, falling back to 'spam_khac' on low confidence.
        best_category = max(final_scores, key=final_scores.get)
        if final_scores[best_category] < 0.3:
            best_category = 'spam_khac'

        subcategories.append(best_category)

        # Uncomment to log the per-message scores while debugging:
        # print(f"Text: {text[:50]}...")
        # print(f"BERT scores: {bert_scores}")
        # print(f"Keyword scores: {keyword_scores}")
        # print(f"Final: {best_category}")

    return subcategories

## Đánh giá model

In [21]:
def evaluate_weighted_knn_accuracy(test_embeddings, test_labels, test_metadata, index, train_metadata, class_weights, alpha, k_values=(1, 3, 5)):
    """Measure weighted-KNN accuracy for several values of k.

    Each message in ``test_metadata`` is classified with
    ``classify_with_weighted_knn`` (using the module-level ``model``,
    ``tokenizer`` and ``device``) and compared with its ``"label"``.

    Args:
        test_embeddings: Only its length is used (number of samples to evaluate).
        test_labels: Unused; labels are read from ``test_metadata``.
        test_metadata: Sequence of dicts with ``"message"``, ``"label"`` and
            ``"index"`` keys.
        index: Search index passed through to the classifier.
        train_metadata: Metadata aligned with the vectors in ``index``.
        class_weights: Mapping from label to weight.
        alpha: Mixing factor passed to the classifier.
        k_values: Iterable of neighbour counts to evaluate.

    Returns:
        ``(results, all_errors)``: ``results`` maps k to accuracy and
        ``all_errors`` maps k to a list of dicts describing each
        misclassified sample. With no samples, every accuracy is 0.0 and
        every error list is empty.
    """
    total = len(test_embeddings)
    if total == 0:
        # Avoid ZeroDivisionError on an empty evaluation set.
        print("No evaluation samples provided; accuracy reported as 0.0.")
        return {k: 0.0 for k in k_values}, {k: [] for k in k_values}

    results = {}
    all_errors = {}

    for k in k_values:
        print(f"\nEvaluating with k={k}, alpha={alpha:.1f}")
        correct = 0
        errors = []

        for i in tqdm(range(total), desc=f"Evaluating k={k}"):
            sample = test_metadata[i]
            query_text = sample["message"]
            true_label = sample["label"]

            result = classify_with_weighted_knn(
                query_text, model, tokenizer, device, index, train_metadata,
                class_weights, k=k, alpha=alpha, explain=False
            )

            predicted_label = result["prediction"]

            if predicted_label == true_label:
                correct += 1
                continue

            # Keep enough context to inspect the mistake later.
            errors.append({
                "index": i,
                "original_index": sample["index"],
                "message": query_text,
                "true_label": true_label,
                "predicted_label": predicted_label,
                "vote_scores": result["vote_scores"],
                "neighbors": result["neighbors"]
            })

        accuracy = correct / total
        results[k] = accuracy
        all_errors[k] = errors

        print(f"Accuracy with k={k}: {accuracy:.4f} ({accuracy*100:.2f}%)")
        print(f"Errors: {len(errors)}/{total}")

    return results, all_errors

In [22]:
def enhanced_spam_classifier_pipeline(user_input, index, train_metadata, class_weights, best_alpha, k=5, explain=False):
    """Classify one message, print a report, and return the details.

    Runs ``classify_with_weighted_knn`` with the module-level ``model``,
    ``tokenizer`` and ``device``; when the prediction is "spam" it also asks
    ``classify_spam_subcategory`` for a sub-category.

    Returns:
        dict with ``"prediction"``, ``"subcategory"`` (None for ham),
        ``"vote_scores"``, ``"neighbors"``, ``"saliency_weight"`` and
        ``"alpha"``; ``"tokens"`` and ``"saliency_scores"`` are added when
        ``explain`` is set and the classifier returned a non-empty token list.
    """
    print(f'\n***Classifying: "{user_input}"')
    print(f"***Using alpha={best_alpha:.1f}, k={k}")

    # Weighted KNN vote.
    knn_output = classify_with_weighted_knn(
        user_input, model, tokenizer, device, index, train_metadata,
        class_weights, k=k, alpha=best_alpha, explain=explain
    )
    label = knn_output["prediction"]
    votes = knn_output["vote_scores"]
    nearest = knn_output["neighbors"]

    print(f"***Prediction: {label.upper()}")
    print(f"***Vote Scores: Ham={votes['ham']:.3f}, Spam={votes['spam']:.3f}")
    print(f"***Saliency Weight: {knn_output['saliency_weight']:.3f}")

    # Only spam gets a sub-category.
    subcategory = None
    if label == "spam":
        found = classify_spam_subcategory([user_input], model, tokenizer, device)
        subcategory = found[0] if found else "spam_khac"
        print(f"***Spam Subcategory: {subcategory}")

    print("\n***Top neighbors:")
    for rank, item in enumerate(nearest, 1):
        print(f"{rank}. Label: {item['label']} | Similarity: {item['score']:.4f} | Weight: {item['weight']:.4f}")
        print(f"   Message: {item['message']}")
        print()

    summary = {
        "prediction": label,
        "subcategory": subcategory,
        "vote_scores": votes,
        "neighbors": nearest,
        "saliency_weight": knn_output["saliency_weight"],
        "alpha": best_alpha
    }

    # Token-level explanation is forwarded only when it exists.
    if explain and knn_output.get("tokens"):
        summary.update(
            tokens=knn_output["tokens"],
            saliency_scores=knn_output["saliency_scores"],
        )

    return summary

## Run Pipeline (Bản mới)


In [24]:
def run_enhanced_pipeline(messages, labels, test_size=0.2, use_augmentation=True):
    """Run the complete enhanced spam classification pipeline.

    Steps: optional interactive data augmentation, label encoding, embedding,
    a stratified train/test split, FAISS inner-product index construction,
    class-weight calculation, alpha search, evaluation for k in (1, 3, 5),
    spam sub-category analysis, and a dump of the results to
    ``enhanced_results.json``.

    Relies on the module-level ``model``, ``tokenizer``, ``device`` and
    ``model_name``. When augmentation is enabled the function prompts on
    stdin for the augmentation mode and (optionally) an API key.

    Args:
        messages: Message texts.
        labels: Labels aligned with ``messages``.
        test_size: Fraction of samples held out by ``train_test_split``.
        use_augmentation: Whether to try to augment the data first. If
            augmentation is disabled or fails, the ``messages``/``labels``
            passed by the caller are used unchanged.

    Returns:
        dict with ``"index"``, ``"train_metadata"``, ``"test_metadata"``,
        ``"class_weights"``, ``"best_alpha"`` and ``"results"``.
    """

    print("=== Enhanced Spam Classification Pipeline ===")

    # 1. Data augmentation
    if use_augmentation:
        print("\n=== Data Augmentation ===")

        try:
            # 1.1. Ask the user which augmentation mode to use
            print("Chọn cách data augmentation:")
            print("1. Sinh 70 câu tinh vi bằng LLM (dùng API Together.ai)")
            print("2. Dùng cụm câu có sẵn trong code (không cần mạng/API)")
            aug_mode = input("👉 Nhập 1 hoặc 2: ").strip()

            use_llm = aug_mode == "1"
            gen = HardExampleGenerator(
                dataset_path="/content/2cls_spam_text_cls.csv",
                alpha_spam=1.0,
                alpha_ham=0.3,
                use_llm_phrases=use_llm
            )

            if use_llm:
                api_key = input("🔑 Nhập Together.ai API key (nhấn Enter để bỏ qua): ").strip()
                if api_key:
                    llm_spam = gen.generate_like_spam_ham(label='spam', n_per_group=10, api_key=api_key)
                    llm_ham = gen.generate_like_spam_ham(label='ham', n_per_group=10, api_key=api_key)
                    gen.load_llm_phrases(spam_list=llm_spam, ham_list=llm_ham)
                else:
                    # No key: fall back to the phrases built into the generator.
                    print("⚠️ Không có API key. Sử dụng cụm có sẵn.")
                    gen.spam_groups = gen._init_spam_phrases()
                    gen.ham_groups = gen._init_ham_phrases()
            else:
                print("ℹ️ Sử dụng cụm đã được hardcode trong class.")

            # 1.2. Generate the extra data
            gen.generate_hard_spam("/content/hard_spam_generated_auto.csv")
            gen.generate_hard_ham("/content/hard_ham_generated_auto.csv")
            augmented_messages, augmented_labels = gen.generate_synonym_replacement(messages, labels, aug_ratio=0.2)

            # 1.3. Merge everything into one new DataFrame
            # NOTE(review): the base rows come from gen.df (loaded from the
            # hard-coded dataset_path above), not from the `messages` argument.
            # Confirm this is intended when the caller supplies another dataset.
            df_base = gen.df
            df_hard_spam = pd.read_csv("/content/hard_spam_generated_auto.csv")
            df_hard_ham = pd.read_csv("/content/hard_ham_generated_auto.csv")
            df_synonym = pd.DataFrame({"Category": augmented_labels, "Message": augmented_messages})
            df = pd.concat([df_base, df_hard_spam, df_hard_ham, df_synonym], ignore_index=True)

            print(f"📈 Tổng dữ liệu sau augmentation: {len(df)} samples.")

            # 1.4. Replace messages & labels together, so a failure while
            # reading one column cannot leave the two lists out of step.
            merged_messages = df["Message"].tolist()
            merged_labels = df["Category"].tolist()
            messages, labels = merged_messages, merged_labels

        except Exception as e:
            # Best-effort step: on any failure continue with the caller's data
            # (previously a hard-coded CSV was re-read here, discarding it).
            print(f"⚠️ Augmentation failed: {e}")
            print("ℹ️ Tiếp tục với dữ liệu gốc...")
            messages, labels = list(messages), list(labels)
    else:
        # Use exactly what the caller passed in.
        print("ℹ️ Data augmentation disabled")
        messages, labels = list(messages), list(labels)

    # 1.5. Encode labels (after augmentation, if any)
    le = LabelEncoder()
    y = le.fit_transform(labels)

    # 2. Generate embeddings
    print("Generating embeddings...")
    X_embeddings = get_embeddings(messages, model, tokenizer, device)

    # 3. Create metadata
    metadata = [{"index": i, "message": message, "label": label, "label_encoded": y[i]}
                for i, (message, label) in enumerate(zip(messages, labels))]

    # 4. Train-test split (stratified on the label)
    X_train_emb, X_test_emb, train_metadata, test_metadata = train_test_split(
        X_embeddings, metadata, test_size=test_size, random_state=42,
        stratify=[m["label"] for m in metadata]
    )

    # 5. Create FAISS index
    print("Creating FAISS index...")
    dimension = X_train_emb.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(X_train_emb.astype("float32"))

    # 6. Calculate class weights
    train_labels = [m["label"] for m in train_metadata]
    class_weights = calculate_class_weights(train_labels)

    # 7. Optimize alpha parameter
    # NOTE(review): alpha is tuned on the same split that step 8 evaluates on,
    # so the reported accuracy is optimistic; a separate validation split
    # would give an unbiased estimate.
    test_labels = [m["label"] for m in test_metadata]
    best_alpha, alpha_results = optimize_alpha_parameter(
        X_test_emb, test_labels, test_metadata, index, train_metadata, class_weights
    )

    # 8. Final evaluation
    print("\n=== Final Evaluation ===")
    accuracy_results, error_results = evaluate_weighted_knn_accuracy(
        X_test_emb, test_labels, test_metadata, index, train_metadata,
        class_weights, best_alpha, k_values=[1, 3, 5]
    )

    # 9. Analyze spam subcategories
    subcat_counts = Counter()
    spam_texts = [m["message"] for m in test_metadata if m["label"] == "spam"]
    if spam_texts:
        print("\n=== Spam Subcategory Analysis ===")
        spam_subcategories = classify_spam_subcategory(spam_texts, model, tokenizer, device)
        subcat_counts = Counter(spam_subcategories)

        print("Spam subcategory distribution:")
        for subcat, count in subcat_counts.items():
            print(f"  {subcat}: {count} ({count/len(spam_texts)*100:.1f}%)")

    # 10. Save results
    results = {
        "timestamp": datetime.now().isoformat(),
        "model": model_name,
        "test_size": len(X_test_emb),
        "best_alpha": best_alpha,
        "alpha_results": alpha_results,
        "accuracy_results": accuracy_results,
        "class_weights": class_weights,
        "spam_subcategories": dict(subcat_counts)
    }

    with open("enhanced_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print("\n*** Results saved to enhanced_results.json ***")

    return {
        "index": index,
        "train_metadata": train_metadata,
        "test_metadata": test_metadata,
        "class_weights": class_weights,
        "best_alpha": best_alpha,
        "results": results,
    }

## Apply test case

In [26]:
from IPython.display import HTML, display

def render_heatmap(tokens, saliencies):
    """Build an HTML heatmap in which each token is shaded red by its saliency.

    Args:
        tokens: Iterable of token strings.
        saliencies: Iterable of scores aligned with ``tokens``; each score is
            used as the alpha channel of the red background, formatted to two
            decimals.

    Returns:
        An ``IPython.display.HTML`` object wrapping one ``<span>`` per token.
        Pairs beyond the shorter of the two inputs are dropped (``zip``).
    """
    from html import escape

    # Token text comes from arbitrary messages, so it is escaped; a raw '<'
    # or '&' would otherwise be interpreted as markup and corrupt the output.
    spans = "".join(
        f"<span style='background: rgba(255,0,0,{score:.2f})'>{escape(str(token))}</span> "
        for token, score in zip(tokens, saliencies)
    )
    return HTML(f"<div style='font-size:16px; line-height:1.5'>{spans}</div>")

def test_enhanced_classifier(pipeline_results):
    """Run the classifier on hand-written (and optionally LLM-generated) examples.

    Reads the index, training metadata, class weights and best alpha from
    ``pipeline_results``, prompts for an optional Together.ai API key, and
    prints the prediction and a saliency heatmap for each example. Returns
    nothing.
    """
    index, train_metadata, class_weights, best_alpha = (
        pipeline_results[field]
        for field in ("index", "train_metadata", "class_weights", "best_alpha")
    )

    # Hand-written examples with the label we expect for each.
    cases = [
        {"message": "Congratulations! You've won $1000! Click here to claim your prize now!", "expected": "spam"},
        {"message": "URGENT: Your account will be suspended. Verify immediately to avoid closure.", "expected": "spam"},
        {"message": "Thanks for your help with the project. The presentation went very well.", "expected": "ham"},
        {"message": "Chào bạn, bạn có khỏe không? Ngày mai mình gặp nhau uống cà phê nhé?", "expected": "ham"},
        {"message": "Cuộc họp đã được dời lại lúc 3 giờ chiều. Bạn vui lòng cập nhật lịch.", "expected": "ham"},
        {"message": "Cảm ơn bạn đã giúp đỡ dự án. Buổi thuyết trình đã diễn ra rất tốt.", "expected": "ham"},
    ]

    api_key = input("Nhập Together.ai API key (nhấn Enter để bỏ qua): ").strip()

    # The generator is created in both modes; LLM phrases are enabled only with a key.
    gen = HardExampleGenerator("/content/2cls_spam_text_cls.csv", use_llm_phrases=bool(api_key))

    if not api_key:
        print("ℹ Không có API key. Sử dụng test case thủ công.")
    else:
        print("\n\n=== Generating LLM-based Test Cases ===")
        try:
            llm_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
            spam_like = gen.generate_user_like_spam_ham(
                label="spam", n=5, api_key=api_key, model=llm_model
            )
            ham_like = gen.generate_user_like_spam_ham(
                label="ham", n=5, api_key=api_key, model=llm_model
            )
            for expected, generated in (("spam", spam_like), ("ham", ham_like)):
                for text in generated:
                    cases.append({"message": text, "expected": expected})
            print(" LLM-generated examples created.")
        except Exception as e:
            print(f" Lỗi sinh ví dụ LLM: {e}")

    print("\n\n=== BẮT ĐẦU TEST ===")
    for number, case in enumerate(cases, 1):
        message, expected_label = case["message"], case["expected"]

        print(f"\n--- Example {number}: {message[:50]}... ---")

        outcome = enhanced_spam_classifier_pipeline(
            message, index, train_metadata, class_weights, best_alpha, k=5, explain=True
        )

        print(f" Prediction: {outcome['prediction']} (Expected: {expected_label})")

        # The heatmap needs both the tokens and their scores.
        has_explanation = "tokens" in outcome and "saliency_scores" in outcome
        if not has_explanation:
            print(" Không thể hiển thị heatmap.")
            continue

        print(" Saliency Heatmap:")
        display(render_heatmap(outcome["tokens"], outcome["saliency_scores"]))


In [27]:
def integrate_with_existing_data(messages, labels):
    """Run the enhanced pipeline on caller-supplied data, then demo it.

    Args:
        messages: List of email texts.
        labels: List of 'ham'/'spam' labels.

    Returns:
        The dict produced by ``run_enhanced_pipeline``.
    """
    print("=== Starting Enhanced Pipeline ===")

    # Build the index and tune alpha, then exercise the classifier on examples.
    outcome = run_enhanced_pipeline(messages, labels)
    test_enhanced_classifier(outcome)
    return outcome


In [None]:
if __name__ == "__main__":
    # Toggle which dataset sources are exercised.
    TEST_KAGGLE = False
    TEST_GDRIVE = True

    results_summary = {}
    banner = "=" * 60

    # (enabled, banner title, summary key, load_dataset keyword arguments)
    dataset_runs = [
        (TEST_KAGGLE, "TESTING WITH KAGGLE DATASET", "kaggle",
         {"source": "kaggle"}),
        (TEST_GDRIVE, "TESTING WITH GOOGLE DRIVE DATASET", "gdrive",
         {"source": "gdrive", "file_id": "1N7rk-kfnDFIGMeX0ROVTjKh71gcgx-7R"}),
    ]

    for enabled, title, summary_key, load_kwargs in dataset_runs:
        if not enabled:
            continue

        print("\n" + banner)
        print(title)
        print(banner)

        messages, labels = load_dataset(**load_kwargs)

        # Train/evaluate, record the headline numbers, then demo on examples.
        pipeline_results = run_enhanced_pipeline(messages, labels, test_size=0.2, use_augmentation=True)
        results_summary[summary_key] = {
            'samples': len(messages),
            'best_alpha': pipeline_results['best_alpha'],
            'accuracy': pipeline_results['results']['accuracy_results']
        }
        test_enhanced_classifier(pipeline_results)

    # Final summary across every dataset that was run.
    print("\n" + banner)
    print("FINAL RESULTS SUMMARY")
    print(banner)

    for dataset, results in results_summary.items():
        print(f"\n{dataset.upper()} Dataset:")
        if 'error' in results:
            print(f"Error: {results['error']}")
            continue

        print(f"Samples: {results['samples']}")
        print(f"Best Alpha: {results['best_alpha']:.1f}")
        print(f"Accuracy Results:")
        for k, acc in results['accuracy'].items():
            print(f"      k={k}: {acc:.4f} ({acc*100:.2f}%)")



TESTING WITH GOOGLE DRIVE DATASET
Downloading...
From: https://drive.google.com/uc?id=1N7rk-kfnDFIGMeX0ROVTjKh71gcgx-7R
To: /content/2cls_spam_text_cls.csv
100% 486k/486k [00:00<00:00, 80.8MB/s]
First 5 records:
  Category                                            Message
0      ham  Go until jurong point, crazy.. Available only ...
1      ham                      Ok lar... Joking wif u oni...
2     spam  Free entry in 2 a wkly comp to win FA Cup fina...
3      ham  U dun say so early hor... U c already then say...
4      ham  Nah I don't think he goes to usf, he lives aro...
Preprocessing dataframe...
Columns available: ['Category', 'Message']
Using text column: Message
Using label column: Category
Label distribution:
  ham: 4825 samples
  spam: 747 samples
Processed 5572 messages
=== Enhanced Spam Classification Pipeline ===

=== Data Augmentation ===
Chọn cách data augmentation:
1. Sinh 70 câu tinh vi bằng LLM (dùng API Together.ai)
2. Dùng cụm câu có sẵn trong code (không cần mạn

Generating embeddings: 100%|██████████| 372/372 [00:57<00:00,  6.47it/s]


Creating FAISS index...
Class distribution:
  spam: 3975 samples (weight: 1.197)
  ham: 5541 samples (weight: 0.859)
Optimizing alpha parameter...


Testing alpha values:   9%|▉         | 1/11 [00:29<04:52, 29.21s/it]

Alpha: 0.0, Accuracy: 0.9626


Testing alpha values:  18%|█▊        | 2/11 [00:58<04:23, 29.22s/it]

Alpha: 0.1, Accuracy: 0.9626


Testing alpha values:  27%|██▋       | 3/11 [01:27<03:51, 28.98s/it]

Alpha: 0.2, Accuracy: 0.9626


Testing alpha values:  36%|███▋      | 4/11 [02:05<03:49, 32.83s/it]

Alpha: 0.3, Accuracy: 0.9626


Testing alpha values:  45%|████▌     | 5/11 [02:50<03:42, 37.04s/it]

Alpha: 0.4, Accuracy: 0.9626


Testing alpha values:  55%|█████▍    | 6/11 [03:26<03:03, 36.63s/it]

Alpha: 0.5, Accuracy: 0.9626


## Kết quả dùng LLM hoàn toàn tự động

In [None]:
if __name__ == "__main__":
    # Flags selecting which dataset sources are run by this cell.
    TEST_KAGGLE = False
    TEST_GDRIVE = True

    # Maps a dataset key ('kaggle' / 'gdrive') to its sample count, best alpha
    # and per-k accuracy results.
    results_summary = {}

    # Test 1: Kaggle Dataset
    if TEST_KAGGLE:
        print("\n" + "="*60)
        print("TESTING WITH KAGGLE DATASET")
        print("="*60)


        messages, labels = load_dataset(source='kaggle')

        # Run enhanced pipeline
        pipeline_results = run_enhanced_pipeline(messages, labels, test_size=0.2, use_augmentation=True)

        results_summary['kaggle'] = {
            'samples': len(messages),
            'best_alpha': pipeline_results['best_alpha'],
            'accuracy': pipeline_results['results']['accuracy_results']
        }

        # Test with examples
        test_enhanced_classifier(pipeline_results)



    # Test 2: Google Drive Dataset
    if TEST_GDRIVE:
        print("\n" + "="*60)
        print("TESTING WITH GOOGLE DRIVE DATASET")
        print("="*60)

        messages, labels = load_dataset(source='gdrive', file_id='1N7rk-kfnDFIGMeX0ROVTjKh71gcgx-7R')

        # Run enhanced pipeline
        pipeline_results = run_enhanced_pipeline(messages, labels, test_size=0.2, use_augmentation=True)
        results_summary['gdrive'] = {
            'samples': len(messages),
            'best_alpha': pipeline_results['best_alpha'],
            'accuracy': pipeline_results['results']['accuracy_results']
        }

        # Test with examples
        test_enhanced_classifier(pipeline_results)



    # Final Summary
    print("\n" + "="*60)
    print("FINAL RESULTS SUMMARY")
    print("="*60)

    for dataset, results in results_summary.items():
        print(f"\n{dataset.upper()} Dataset:")
        # NOTE(review): nothing in this cell stores an 'error' key, so this
        # branch is only reachable if another writer adds one.
        if 'error' in results:
            print(f"Error: {results['error']}")
        else:
            print(f"Samples: {results['samples']}")
            print(f"Best Alpha: {results['best_alpha']:.1f}")
            print(f"Accuracy Results:")
            # One line per evaluated k with its accuracy as a fraction and a percentage.
            for k, acc in results['accuracy'].items():
                print(f"      k={k}: {acc:.4f} ({acc*100:.2f}%)")