In [None]:
import os
import json
from collections import defaultdict
import emoji
import re
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModel
from torch.nn.utils.rnn import pad_sequence
from sklearn import metrics
import matplotlib.pyplot as plt
from tqdm import tqdm
from tabulate import tabulate
from colorama import init, Fore, Style

In [None]:
data_dir = "data/task1/train/subjects"
subjects = defaultdict(list)

def extract_emojis(text):
    return set(emoji_data['emoji'] for emoji_data in emoji.emoji_list(text))

all_emojis = set()
for filename in os.listdir(data_dir):
    if filename.endswith(".json"):
        with open(os.path.join(data_dir, filename), "r") as f:
            messages = json.load(f)
            nick = filename.split(".")[0]
            for msg in messages:
                message_text = str(msg["message"]) if msg["message"] is not None else ""
                all_emojis.update(extract_emojis(message_text))

Tất cả các biểu tượng cảm xúc trong dữ liệu:
🫀 - :anatomical_heart:
😱 - :face_screaming_in_fear:
🇮🇳 - :India:
🎉 - :party_popper:
🐬 - :dolphin:
🥳 - :partying_face:
😉 - :winking_face:
⛔ - :no_entry:
🚀 - :rocket:
😶 - :face_without_mouth:
👌🏼 - :OK_hand_medium-light_skin_tone:
👆 - :backhand_index_pointing_up:
😠 - :angry_face:
🌚 - :new_moon_face:
😕 - :confused_face:
🤦🏻‍♂ - :man_facepalming_light_skin_tone:
👆🏽 - :backhand_index_pointing_up_medium_skin_tone:
🐻 - :bear:
😴 - :sleeping_face:
💶 - :euro_banknote:
😔 - :pensive_face:
🙈 - :see-no-evil_monkey:
🥺 - :pleading_face:
😪 - :sleepy_face:
🦖 - :T-Rex:
💰 - :money_bag:
🥚 - :egg:
🍎 - :red_apple:
✅ - :check_mark_button:
😮 - :face_with_open_mouth:
☹ - :frowning_face:
👏 - :clapping_hands:
♥ - :heart_suit:
😨 - :fearful_face:
🤦🏻‍♀ - :woman_facepalming_light_skin_tone:
🔪 - :kitchen_knife:
❗ - :red_exclamation_mark:
🐮 - :cow_face:
🎈 - :balloon:
📊 - :bar_chart:
👏🏻 - :clapping_hands_light_skin_tone:
📈 - :chart_increasing:
🟣 - :purple_circle:
😳 - :flushed_f

In [None]:
def map_emoji_to_spanish(emoji=None):
    emoji_map = {
        "🔝": "arriba",
        "👎": "no me gusta",
        "😳": "sorprendido",
        "4️⃣": "cuatro",
        "🖐🏼": "mano abierta",
        "💎": "diamante",
        "🤣": "riendo fuerte",
        "🤞🏻": "dedos cruzados",
        "🍺": "cerveza",
        "❣": "corazón exclamación",
        "🤡": "payaso",
        "🎅🏻": "Papá Noel",
        "⬆": "subir",
        "💸": "dinero volando",
        "🤤": "babeando",
        "❌": "cruz",
        "🙌🏻": "manos arriba",
        "🤩": "asombrado",
        "🇵🇪": "Perú",
        "🤠": "vaquero",
        "🟣": "círculo morado",
        "🖐🏽": "mano abierta",
        "🙃": "cara invertida",
        "🐸": "rana",
        "👆🏼": "señalando arriba",
        "🈚": "gratis",
        "🌐": "mundo",
        "🎁": "regalo",
        "🎉": "celebración",
        "😵‍💫": "mareado",
        "🌝": "luna llena",
        "🙋‍♂": "hombre levantando mano",
        "3️⃣": "tres",
        "🔮": "bola de cristal",
        "😰": "nervioso",
        "😨": "miedo",
        "❓": "pregunta",
        "☝🏻": "dedo arriba",
        "🥲": "lágrimas de alegría",
        "✊🏼": "puño levantado",
        "✊": "puño",
        "🧘🏻‍♂": "meditación",
        "🧐": "curioso",
        "👏🏾": "aplausos",
        "🐳": "ballena",
        "💪🏼": "fuerza",
        "✅": "aprobado",
        "🤦🏼‍♂": "vergüenza",
        "😍": "enamorado",
        "👻": "fantasma",
        "😂": "riendo",
        "💪🏻": "fuerte",
        "🫤": "decepción",
        "⚽": "fútbol",
        "🥚": "huevo",
        "🙏": "rezando",
        "🤙": "llámame",
        "🙄": "aburrido",
        "😲": "asombro",
        "♥": "corazón",
        "🍎": "manzana",
        "🐻": "oso",
        "🤪": "loco",
        "👆🏽": "señalando arriba",
        "🎢": "montaña rusa",
        "🙌": "celebrando",
        "🌘": "luna menguante",
        "🫡": "saludo",
        "🙋🏻‍♀": "mujer levantando mano",
        "🤦‍♂": "error",
        "🌊": "ola",
        "😉": "guiño",
        "🥶": "frío",
        "💋": "beso",
        "🇺🇦": "Ucrania",
        "😶‍🌫": "confundido",
        "🌬": "viento",
        "💩": "mierda",
        "👌🏼": "perfecto",
        "🙆‍♂": "hombre OK",
        "💪🏽": "fuerza",
        "😱": "gritando",
        "1️⃣": "uno",
        "🤘": "rock",
        "👉": "señalando derecha",
        "🙂": "sonriendo",
        "👁": "ojo",
        "👀": "ojos",
        "🔥": "fuego",
        "⏺": "grabar",
        "😅": "sudando",
        "❗": "exclamación",
        "😕": "confuso",
        "🥒": "pepino",
        "🎂": "torta",
        "😥": "aliviado",
        "✌🏽": "victoria",
        "🎾": "tenis",
        "💚": "corazón verde",
        "💔": "corazón roto",
        "👍": "bien",
        "🐶": "perro",
        "✔": "verificado",
        "✌🏻": "paz",
        "💪": "músculo",
        "🎈": "globo",
        "🤑": "dinero en la cara",
        "😾": "gato enfadado",
        "💵": "billete",
        "👋🏻": "saludando",
        "👈🏻": "señalando izquierda",
        "💰": "bolsa de dinero",
        "🎼": "música",
        "🐮": "vaca",
        "🇦🇷": "Argentina",
        "🤷🏼‍♀": "mujer encogiéndose",
        "💃": "bailando",
        "🤮": "vomitando",
        "🇷🇺": "Rusia",
        "😎": "genial",
        "🥳": "fiesta",
        "⚰": "ataúd",
        "💯": "cien puntos",
        "📈": "gráfico subiendo",
        "😭": "llorando",
        "😪": "somnoliento",
        "🤞🏼": "suerte",
        "🤦🏽‍♂": "hombre avergonzado",
        "▶": "reproducir",
        "⛔": "prohibido",
        "🎶": "notas musicales",
        "🙊": "mono callado",
        "🌚": "luna nueva",
        "👏": "aplaudiendo",
        "🙏🏽": "rezando",
        "😄": "feliz",
        "🤦🏻‍♂": "error hombre",
        "🇨🇳": "China",
        "👌🏻": "OK",
        "🤙🏻": "llámame",
        "🇳🇬": "Nigeria",
        "😃": "alegre",
        "ℹ️": "información",
        "🗣": "hablando",
        "🙌🏼": "manos levantadas",
        "🤞": "cruzando dedos",
        "😜": "broma",
        "🎵": "nota musical",
        "🤟": "te amo",
        "✈": "avión",
        "👌🏽": "perfecto",
        "🤦🏽": "vergüenza",
        "👍🏾": "bien",
        "🔹": "diamante azul",
        "😝": "lengua fuera",
        "💶": "euro",
        "🤓": "nerd",
        "😶": "sin expresión",
        "🐁": "ratón",
        "🐗": "jabalí",
        "🤦🏻‍♀": "mujer avergonzada",
        "🍏": "manzana verde",
        "🟢": "círculo verde",
        "🙌🏽": "celebración",
        "🇪🇸": "España",
        "✨": "brillo",
        "🤷🏻‍♂": "hombre encogiéndose",
        "🚨": "alarma",
        "🥰": "amor",
        "☺": "sonrisa",
        "🤷‍♂": "duda",
        "🤯": "cabeza explotando",
        "🥺": "suplicando",
        "🐟": "pez",
        "🇮🇳": "India",
        "😐": "neutral",
        "😁": "sonriendo amplio",
        "🙋🏻‍♂": "levantando mano",
        "😓": "sudor",
        "🕺": "bailando",
        "😯": "sorprendido",
        "👉🏻": "señalando derecha",
        "💥": "explosión",
        "😢": "llorando",
        "🦖": "T-Rex",
        "⚡": "rayo",
        "😴": "durmiendo",
        "🫣": "espiando",
        "😻": "gato enamorado",
        "🥵": "caliente",
        "👍🏻": "pulgar arriba",
        "🇧🇾": "Bielorrusia",
        "🤷🏽‍♀": "mujer dudando",
        "😋": "saboreando",
        "🚫": "prohibido",
        "👅": "lengua",
        "😆": "riendo mucho",
        "😊": "sonriendo feliz",
        "😇": "ángel",
        "😠": "enojado",
        "🌎": "Américas",
        "⬇": "bajar",
        "😞": "triste",
        "🔵": "círculo azul",
        "📨": "correo",
        "👆": "arriba",
        "😘": "besando",
        "🌖": "luna gibosa",
        "❤": "corazón rojo",
        "☝": "dedo arriba",
        "✌": "victoria",
        "🍻": "brindis",
        "🤝": "apretón de manos",
        "👋": "saludo",
        "💲": "dólar",
        "👍🏼": "bien",
        "🚶🏻‍♂": "hombre caminando",
        "🤔": "pensando",
        "😹": "gato riendo",
        "🫵": "señalando",
        "🤭": "riendo callado",
        "🪂": "paracaídas",
        "😈": "diablo",
        "🔰": "principiante",
        "🫀": "corazón",
        "😒": "molesto",
        "🤷": "no sé",
        "😀": "felicidad",
        "🍀": "trébol",
        "🔪": "cuchillo",
        "😮": "boca abierta",
        "💬": "hablar",
        "✋": "mano levantada",
        "😌": "alivio",
        "💦": "sudor",
        "🤷🏼‍♂": "duda",
        "☹": "tristeza",
        "🤨": "sospecha",
        "🤙🏽": "llámame",
        "🔻": "triángulo abajo",
        "🛍": "compras",
        "🤧": "estornudo",
        "💫": "mareo",
        "👼": "ángel",
        "🤌": "pellizco",
        "💨": "rápido",
        "😛": "lengua fuera",
        "🎄": "árbol de Navidad",
        "🥹": "lágrimas contenidas",
        "☀": "sol",
        "🌕": "luna llena",
        "🇺🇸": "Estados Unidos",
        "👏🏼": "aplausos",
        "‼": "doble exclamación",
        "🚀": "cohete",
        "😡": "furioso",
        "😬": "nervios",
        "🔴": "círculo rojo",
        "🙏🏻": "orando",
        "🙈": "mono tapándose",
        "🦥": "perezoso",
        "🌙": "luna creciente",
        "👈": "señalando izquierda",
        "🐷": "cerdo",
        "🥸": "disfrazado",
        "😏": "sonrisa pícara",
        "😚": "beso cerrado",
        "⚓": "ancla",
        "👌": "OK",
        "🤟🏻": "te amo",
        "🌌": "vía láctea",
        "⚠": "advertencia",
        "🥱": "bostezando",
        "🐬": "delfín",
        "📊": "gráfico",
        "🐀": "rata",
        "🤗": "abrazo",
        "😔": "pensativo",
        "👏🏻": "aplaudiendo",
        "🇧🇬": "Bulgaria",
        "🥴": "mareado"
    }
    if emoji is None:
        return emoji_map  # Trả về toàn bộ từ điển nếu không truyền emoji
    return emoji_map.get(emoji, emoji)  # Trả về ánh xạ hoặc emoji gốc

def replace_emojis_in_text(text):
    result = text
    for emoji, spanish_text in map_emoji_to_spanish().items():
        result = result.replace(emoji, f" {spanish_text} ")
    words = result.split()
    if not words:
        return ""
    cleaned_words = [words[0]]
    for i in range(1, len(words)):
        if words[i] != words[i-1]:
            cleaned_words.append(words[i])
    return " ".join(cleaned_words).strip()

for filename in os.listdir(data_dir):
    if filename.endswith(".json"):
        with open(os.path.join(data_dir, filename), "r") as f:
            messages = json.load(f)
            nick = filename.split(".")[0]
            subjects[nick] = [
                replace_emojis_in_text(str(msg["message"]) if msg["message"] is not None else "")
                for msg in messages
            ]

In [6]:
subjects['user10343']

['Voy cargando la escopeta pa longuear alguna monedilla',
 'Todavía no',
 'Esto está más aburrido',
 'Falta mucho para que entren los chinos ?',
 'Por señales de Facu ?',
 'Compartí ahora bien',
 'Y compartilos ahora . Igual te van a banear cuando lean esto riendo',
 'Sacate la papa de la boca',
 'Se picó la clande',
 'riendo',
 'Screenshot ( 6 mar . 2022 20:03 : 38 )',
 'La paciencia sudando',
 'Divino ! aplaudiendo',
 'Screenshot ( 6 mar . 2022 20:22 : 39 )',
 'Cerrada lunitaa enamorado',
 'mono tapándose',
 'Un short en 38k es lo mismo que ir al casino riendo',
 'Recuperar que ?',
 'Si no entra una ballena generosa en los próximos minutos btc se pega un palo',
 'Vamoooo mañana se come enamorado',
 'De proyectos no se absolutamente nada sonriendo feliz solo tradeo cualquier cosa que se mueva riendo',
 'riendo aplaudiendo',
 '4 k en un trade ? Mierda ! Que haces acá hermano',
 'Anoche no estaban meta long todos ?',
 'Yo vi que subian longs a dos manos .. pero no recuerdo si era a anoc

In [None]:
task1_labels = {}
with open("data/task1/train/gold_task1.txt", "r") as f:
    next(f)
    for line in f:
        nick, risk = line.strip().split(",")
        task1_labels[nick] = int(risk)

messages = []
labels = []
for nick, subject_messages in subjects.items():
    if nick in task1_labels:
        if task1_labels[nick] == 0:
            msg_len = len(subject_messages)
            if msg_len > 1:
                split_point = msg_len // 2
                messages.append(subject_messages[:split_point])
                messages.append(subject_messages[split_point:])
                labels.extend([0, 0])
            else:
                messages.append(subject_messages)
                labels.append(0)
        else:
            messages.append(subject_messages)
            labels.append(task1_labels[nick])

Tổng số mẫu: 528, Phân bố nhãn: [356 172]


In [None]:
class EmbDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]


In [None]:
class EmbDatasetRNNAug(Dataset):
    def __init__(self, embeddings, labels, thr_rng=0.7, n_msg=10):
        self.embeddings = embeddings
        self.labels = labels
        self.emb0 = [embeddings[i] for i in range(len(labels)) if labels[i] == 0]
        self.thr_rng = thr_rng
        self.n_msg = n_msg

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.labels[idx] == 0:
            nr_msg = np.random.randint(1, len(self.embeddings[idx]) + 1)
            return self.embeddings[idx][:nr_msg], self.labels[idx]
        else:
            rnd = np.random.uniform()
            if rnd > self.thr_rng:
                neutral = self.emb0[np.random.randint(0, len(self.emb0))]
                n_extra = np.random.randint(1, min(len(neutral), self.n_msg))
                return np.concatenate([neutral[:n_extra], self.embeddings[idx]], axis=0), self.labels[idx]
            return self.embeddings[idx], self.labels[idx]

In [None]:
def make_plot(train_scores, val_scores, y_label, figsize=(8,5)):
    fig, ax = plt.subplots(1,1,figsize=figsize)
    ax.plot(train_scores, label='Train')
    ax.plot(val_scores, label='Val')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    return fig, ax

In [None]:
def get_cls_embeddings(all_messages, model, tokenizer, device, m_length=96):
    model.to(device)
    model.eval()
    embeddings = []
    with torch.no_grad():
        for subject_messages in tqdm(all_messages):
            input = tokenizer(subject_messages, padding=True, truncation=True, max_length=m_length, return_tensors='pt')
            output = model(**input.to(device))
            embeddings.append(output.last_hidden_state[:, 0, :].cpu().numpy())
    return embeddings

In [None]:
def lstm_collate(batch):
    labels = [x[1] for x in batch]
    labels = torch.tensor(labels, dtype=torch.long)
    data = [torch.tensor(x[0], dtype=torch.float32) for x in batch]
    batch_data = pad_sequence(data)
    lens = torch.tensor([len(x) for x in data], dtype=torch.long).unsqueeze(0).unsqueeze(-1)
    lens -= 1
    return batch_data, lens, labels


In [None]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, h_size, output_dim, dropout=0):
        super().__init__()
        self.lstm = nn.LSTM(input_size, h_size, num_layers=1, batch_first=False, dropout=dropout, bidirectional=True)
        self.attention = nn.Linear(2 * h_size, 1)
        self.classifier = nn.Linear(2 * h_size, output_dim)

    def forward(self, seq_data, seq_lens, state=None):
        lstm_out, _ = self.lstm(seq_data)
        attention_weights = torch.softmax(self.attention(lstm_out), dim=0)
        context_vector = torch.sum(attention_weights * lstm_out, dim=0)
        return self.classifier(context_vector)

    def predict_all_timesteps(self, seq_data, seq_lens, state=None):
        lstm_out, _ = self.lstm(seq_data)
        attention_weights = torch.softmax(self.attention(lstm_out), dim=0)
        logits_all = self.classifier(lstm_out)
        pred_all = torch.argmax(logits_all, dim=2)
        ts_predictions = [pred_all[:seq_lens[0, i].item(), i].squeeze().cpu().numpy() for i in range(pred_all.shape[1])]
        return ts_predictions

In [13]:
# Thiết lập và huấn luyện
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
name = 'pysentimiento/robertuito-sentiment-analysis'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)
embs = get_cls_embeddings(messages, model, tokenizer, device, m_length=96)

Some weights of RobertaModel were not initialized from the model checkpoint at pysentimiento/robertuito-sentiment-analysis and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 528/528 [00:31<00:00, 16.93it/s]


In [None]:
def validate_tms_rnn(subject_embs, labels, net, device):
    net.to(device)
    net.eval()
    predictions = []
    with torch.no_grad():
        batch_data, batch_lens, _ = lstm_collate([(embs, label) for embs, label in zip(subject_embs, labels)])
        batch_data, batch_lens = batch_data.to(device), batch_lens.to(device)
        seq_predictions = net.predict_all_timesteps(batch_data, batch_lens)
        for seq_pred, true_label in zip(seq_predictions, labels):
            idxs = np.nonzero(seq_pred)[0]
            predictions.append(seq_pred[idxs[0]] if len(idxs) > 0 else 0)
    return {
        'acc': metrics.accuracy_score(labels, predictions),
        'macro_f1': metrics.f1_score(labels, predictions, average='macro', zero_division=0),
        'cls_report': metrics.classification_report(labels, predictions, zero_division=0),
        'cfm': metrics.confusion_matrix(labels, predictions)
    }
def train_gdro_rnn_sl(net, optimizer, device, criterion, train_dl, q, soft_labels, eta=0.1):
    net.to(device)
    net.train()
    loss = 0
    num_batches = 0
    preds = []
    labels = []
    for batch_data, batch_lens, batch_labels in train_dl:
        labels.append(batch_labels.numpy())
        unique_batch_labels = np.unique(batch_labels.numpy())
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
        batch_lens = batch_lens.to(device)
        optimizer.zero_grad()
        out = net(batch_data, batch_lens)
        batch_losses = nn.functional.cross_entropy(out, soft_labels[batch_labels], reduction='none')
        for cls in unique_batch_labels:
            idx_cls = batch_labels == cls
            q[cls] *= (eta * batch_losses[idx_cls].mean()).exp().item()
        q /= q.sum()
        loss_value = sum(q[cls] * batch_losses[batch_labels == cls].mean() for cls in unique_batch_labels)
        loss_value.backward()
        optimizer.step()
        loss += loss_value.item()
        num_batches += 1
        preds.append(torch.argmax(out, axis=-1).cpu().numpy())
    labels = np.concatenate(labels, axis=0)
    preds = np.concatenate(preds, axis=0)
    return {
        'loss': loss / num_batches,
        'acc': metrics.accuracy_score(labels, preds),
        'macro_f1': metrics.f1_score(labels, preds, average='macro', zero_division=0),
        'cls_report': metrics.classification_report(labels, preds, zero_division=0),
        'cfm': metrics.confusion_matrix(labels, preds)
    }, q

def run_train_gdro_rnn_sl(net, optimizer, criterion, device, train_dl, train_embs, train_labels,
                          val_embs, val_labels, soft_labels, output_dir, max_epochs=150, n_classes=2, eta=0.1):
    best_macro_f1_val = 0
    logs = {'train': defaultdict(list), 'val': defaultdict(list), 'train_eval': defaultdict(list), 'epoch': 0}
    q = torch.ones(n_classes, dtype=torch.float32, device=device) / n_classes
    with tqdm(total=max_epochs, desc="Training", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]") as pbar:
        for epoch in range(max_epochs):
            train_report, q = train_gdro_rnn_sl(net, optimizer, device, criterion, train_dl, q, soft_labels, eta)
            logs['train'].update({k: logs['train'][k] + [v] for k, v in train_report.items()})
            val_report = validate_tms_rnn(val_embs, val_labels, net, device)
            logs['val'].update({k: logs['val'][k] + [v] for k, v in val_report.items()})
            pbar.set_postfix({'loss': f"{train_report['loss']:.4f}", 'val_f1': f"{val_report['macro_f1']:.4f}"})
            pbar.update(1)
            if val_report['macro_f1'] >= best_macro_f1_val:
                best_macro_f1_val = val_report['macro_f1']
                torch.save(net.cpu().state_dict(), f'{output_dir}/net_params.pt')
                logs['epoch'] = epoch
                train_report_eval = validate_tms_rnn(train_embs, train_labels, net, device)
                logs['train_eval'].update({k: logs['train_eval'][k] + [v] for k, v in train_report_eval.items()})
    return logs

def format_confusion_matrix(cfm):
    return "\n".join([
        f"{'':>8} {'Pred 0':>8} {'Pred 1':>8}",
        f"{'True 0':>8} {cfm[0,0]:8d} {cfm[0,1]:8d}",
        f"{'True 1':>8} {cfm[1,0]:8d} {cfm[1,1]:8d}"
    ])

ds = EmbDatasetRNNAug(embs, labels, thr_rng=0.6, n_msg=10)
train_ds, test_ds = random_split(ds, [0.8, 0.2], generator=torch.Generator().manual_seed(2909))
train_labels = [labels[i] for i in train_ds.indices]
test_labels = [labels[i] for i in test_ds.indices]
val_embs = [embs[i] for i in test_ds.indices]
val_labels = [labels[i] for i in test_ds.indices]
train_embs = [embs[i] for i in train_ds.indices]
train_labels = [labels[i] for i in train_ds.indices]

save_dir = 'pre_trained_models'
os.makedirs(save_dir, exist_ok=True)
soft_labels = torch.tensor([[0.95, 0.05], [0.05, 0.95]], dtype=torch.float32, device=device)

f1s = []
for h_size in [96, 128]:
    scores = []
    for bs in [2, 4, 8]:
        dir_name = f'processed_data_h_{h_size}_bs_{bs}_0.95_0.05'
        output_dir = os.path.join(save_dir, dir_name)
        os.makedirs(output_dir, exist_ok=True)

        random.seed(2909)
        np.random.seed(2909)
        torch.manual_seed(2909)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, drop_last=False, collate_fn=lstm_collate)
        net = LSTMClassifier(embs[0].shape[-1], h_size=h_size, output_dim=2)
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
        loss_fn = nn.CrossEntropyLoss()

        logs = run_train_gdro_rnn_sl(net, optimizer, loss_fn, device, train_dl, train_embs, train_labels,
                                     val_embs, val_labels, soft_labels, output_dir, max_epochs=150)

        for k in logs['train_eval']:
            if k not in ['cls_report', 'cfm']:
                fig, ax = make_plot(logs['train'][k], logs['val'][k], k)
                fig.savefig(f'{output_dir}/{k}.png')
                plt.close(fig)
                fig, ax = make_plot(logs['train_eval'][k], logs['val'][k], k)
                fig.savefig(f'{output_dir}/{k}_eval.png')
                plt.close(fig)

        np.save(f'{output_dir}/logs.npy', logs, allow_pickle=True)

        arg = np.argmax(logs['val']['macro_f1'])
        table_data = [
            ["Epoch tốt nhất", arg],
            ["Val Macro F1", f"{logs['val']['macro_f1'][arg]:.4f}"],
            ["Train_eval Macro F1", f"{logs['train_eval']['macro_f1'][-1]:.4f}"]
        ]
        print(f"{Fore.CYAN}=== Kết quả cho h_size: {h_size}, batch_size: {bs} ==={Style.RESET_ALL}")
        print(tabulate(table_data, headers=["Metric", "Value"], tablefmt="pretty", colalign=("left", "right")))
        print(f"\n{Fore.GREEN}Validation Set:{Style.RESET_ALL}")
        print(format_confusion_matrix(logs['val']['cfm'][arg]))
        print(f"\n{logs['val']['cls_report'][arg]}")
        print(f"\n{Fore.YELLOW}Train_eval Set:{Style.RESET_ALL}")
        print(format_confusion_matrix(logs['val']['cfm'][-1]))
        print(f"\n{logs['train_eval']['cls_report'][-1]}")
        print(f"{Fore.CYAN}{'='*50}{Style.RESET_ALL}\n")
        scores.append(logs['val']['macro_f1'][arg])
    f1s.append(scores)

Phân bố nhãn trong tập train:
[282 278]
Phân bố nhãn trong tập test:
[74 66]


Training: 100%|██████████| 150/150 [03:06<00:00,  1.24s/it, loss=0.1489, val_f1=0.9569]


[36m=== Kết quả cho h_size: 96, batch_size: 2 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch tốt nhất      |    117 |
| Val Macro F1        | 0.9713 |
| Train_eval Macro F1 | 0.9982 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       72        2
  True 1        2       64

              precision    recall  f1-score   support

           0       0.97      0.97      0.97        74
           1       0.97      0.97      0.97        66

    accuracy                           0.97       140
   macro avg       0.97      0.97      0.97       140
weighted avg       0.97      0.97      0.97       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       72        2
  True 1        4       62

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       282
           1       1.00      1.00      1.00       278

    acc

Training: 100%|██████████| 150/150 [01:53<00:00,  1.32it/s, loss=0.1862, val_f1=0.9353]


[36m=== Kết quả cho h_size: 96, batch_size: 4 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch tốt nhất      |    142 |
| Val Macro F1        | 0.9498 |
| Train_eval Macro F1 | 1.0000 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       71        3
  True 1        4       62

              precision    recall  f1-score   support

           0       0.95      0.96      0.95        74
           1       0.95      0.94      0.95        66

    accuracy                           0.95       140
   macro avg       0.95      0.95      0.95       140
weighted avg       0.95      0.95      0.95       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       71        3
  True 1        6       60

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       282
           1       1.00      1.00      1.00       278

    acc

Training: 100%|██████████| 150/150 [01:05<00:00,  2.30it/s, loss=0.1989, val_f1=0.8049]


[36m=== Kết quả cho h_size: 96, batch_size: 8 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch tốt nhất      |     49 |
| Val Macro F1        | 0.9000 |
| Train_eval Macro F1 | 0.9821 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       62       12
  True 1        2       64

              precision    recall  f1-score   support

           0       0.97      0.84      0.90        74
           1       0.84      0.97      0.90        66

    accuracy                           0.90       140
   macro avg       0.91      0.90      0.90       140
weighted avg       0.91      0.90      0.90       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       49       25
  True 1        2       64

              precision    recall  f1-score   support

           0       1.00      0.96      0.98       282
           1       0.97      1.00      0.98       278

    acc

Training: 100%|██████████| 150/150 [03:45<00:00,  1.51s/it, loss=0.1474, val_f1=0.9071]


[36m=== Kết quả cho h_size: 128, batch_size: 2 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch tốt nhất      |     68 |
| Val Macro F1        | 0.9569 |
| Train_eval Macro F1 | 1.0000 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       72        2
  True 1        4       62

              precision    recall  f1-score   support

           0       0.95      0.97      0.96        74
           1       0.97      0.94      0.95        66

    accuracy                           0.96       140
   macro avg       0.96      0.96      0.96       140
weighted avg       0.96      0.96      0.96       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       65        9
  True 1        4       62

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       282
           1       1.00      1.00      1.00       278

    ac

Training: 100%|██████████| 150/150 [02:29<00:00,  1.00it/s, loss=0.1878, val_f1=0.9284]


[36m=== Kết quả cho h_size: 128, batch_size: 4 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch tốt nhất      |    130 |
| Val Macro F1        | 0.9569 |
| Train_eval Macro F1 | 1.0000 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       72        2
  True 1        4       62

              precision    recall  f1-score   support

           0       0.95      0.97      0.96        74
           1       0.97      0.94      0.95        66

    accuracy                           0.96       140
   macro avg       0.96      0.96      0.96       140
weighted avg       0.96      0.96      0.96       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       68        6
  True 1        4       62

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       282
           1       1.00      1.00      1.00       278

    ac

Training: 100%|██████████| 150/150 [01:51<00:00,  1.35it/s, loss=0.1958, val_f1=0.8712]

[36m=== Kết quả cho h_size: 128, batch_size: 8 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch tốt nhất      |     30 |
| Val Macro F1        | 0.9071 |
| Train_eval Macro F1 | 0.9929 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       65        9
  True 1        4       62

              precision    recall  f1-score   support

           0       0.94      0.88      0.91        74
           1       0.87      0.94      0.91        66

    accuracy                           0.91       140
   macro avg       0.91      0.91      0.91       140
weighted avg       0.91      0.91      0.91       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       58       16
  True 1        2       64

              precision    recall  f1-score   support

           0       0.99      0.99      0.99       282
           1       0.99      0.99      0.99       278

    ac


