In [None]:
import os
import json
from collections import defaultdict
import emoji
import re
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModel
from torch.nn.utils.rnn import pad_sequence
from sklearn import metrics
import matplotlib.pyplot as plt
from tqdm import tqdm
from tabulate import tabulate
from colorama import init, Fore, Style

In [None]:
data_dir = "data/task1/train/subjects"
subjects = defaultdict(list)

def extract_emojis(text):
    return set(emoji_data['emoji'] for emoji_data in emoji.emoji_list(text))

all_emojis = set()
for filename in os.listdir(data_dir):
    if filename.endswith(".json"):
        with open(os.path.join(data_dir, filename), "r") as f:
            messages = json.load(f)
            nick = filename.split(".")[0]
            for msg in messages:
                message_text = str(msg["message"]) if msg["message"] is not None else ""
                all_emojis.update(extract_emojis(message_text))

T·∫•t c·∫£ c√°c bi·ªÉu t∆∞·ª£ng c·∫£m x√∫c trong d·ªØ li·ªáu:
ü´Ä - :anatomical_heart:
üò± - :face_screaming_in_fear:
üáÆüá≥ - :India:
üéâ - :party_popper:
üê¨ - :dolphin:
ü•≥ - :partying_face:
üòâ - :winking_face:
‚õî - :no_entry:
üöÄ - :rocket:
üò∂ - :face_without_mouth:
üëåüèº - :OK_hand_medium-light_skin_tone:
üëÜ - :backhand_index_pointing_up:
üò† - :angry_face:
üåö - :new_moon_face:
üòï - :confused_face:
ü§¶üèª‚Äç‚ôÇ - :man_facepalming_light_skin_tone:
üëÜüèΩ - :backhand_index_pointing_up_medium_skin_tone:
üêª - :bear:
üò¥ - :sleeping_face:
üí∂ - :euro_banknote:
üòî - :pensive_face:
üôà - :see-no-evil_monkey:
ü•∫ - :pleading_face:
üò™ - :sleepy_face:
ü¶ñ - :T-Rex:
üí∞ - :money_bag:
ü•ö - :egg:
üçé - :red_apple:
‚úÖ - :check_mark_button:
üòÆ - :face_with_open_mouth:
‚òπ - :frowning_face:
üëè - :clapping_hands:
‚ô• - :heart_suit:
üò® - :fearful_face:
ü§¶üèª‚Äç‚ôÄ - :woman_facepalming_light_skin_tone:
üî™ - :kitchen_knife:
‚ùó - :red_exclamation_

In [None]:
def map_emoji_to_spanish(emoji=None):
    emoji_map = {
        "üîù": "arriba",
        "üëé": "no me gusta",
        "üò≥": "sorprendido",
        "4Ô∏è‚É£": "cuatro",
        "üñêüèº": "mano abierta",
        "üíé": "diamante",
        "ü§£": "riendo fuerte",
        "ü§ûüèª": "dedos cruzados",
        "üç∫": "cerveza",
        "‚ù£": "coraz√≥n exclamaci√≥n",
        "ü§°": "payaso",
        "üéÖüèª": "Pap√° Noel",
        "‚¨Ü": "subir",
        "üí∏": "dinero volando",
        "ü§§": "babeando",
        "‚ùå": "cruz",
        "üôåüèª": "manos arriba",
        "ü§©": "asombrado",
        "üáµüá™": "Per√∫",
        "ü§†": "vaquero",
        "üü£": "c√≠rculo morado",
        "üñêüèΩ": "mano abierta",
        "üôÉ": "cara invertida",
        "üê∏": "rana",
        "üëÜüèº": "se√±alando arriba",
        "üàö": "gratis",
        "üåê": "mundo",
        "üéÅ": "regalo",
        "üéâ": "celebraci√≥n",
        "üòµ‚Äçüí´": "mareado",
        "üåù": "luna llena",
        "üôã‚Äç‚ôÇ": "hombre levantando mano",
        "3Ô∏è‚É£": "tres",
        "üîÆ": "bola de cristal",
        "üò∞": "nervioso",
        "üò®": "miedo",
        "‚ùì": "pregunta",
        "‚òùüèª": "dedo arriba",
        "ü•≤": "l√°grimas de alegr√≠a",
        "‚úäüèº": "pu√±o levantado",
        "‚úä": "pu√±o",
        "üßòüèª‚Äç‚ôÇ": "meditaci√≥n",
        "üßê": "curioso",
        "üëèüèæ": "aplausos",
        "üê≥": "ballena",
        "üí™üèº": "fuerza",
        "‚úÖ": "aprobado",
        "ü§¶üèº‚Äç‚ôÇ": "verg√ºenza",
        "üòç": "enamorado",
        "üëª": "fantasma",
        "üòÇ": "riendo",
        "üí™üèª": "fuerte",
        "ü´§": "decepci√≥n",
        "‚öΩ": "f√∫tbol",
        "ü•ö": "huevo",
        "üôè": "rezando",
        "ü§ô": "ll√°mame",
        "üôÑ": "aburrido",
        "üò≤": "asombro",
        "‚ô•": "coraz√≥n",
        "üçé": "manzana",
        "üêª": "oso",
        "ü§™": "loco",
        "üëÜüèΩ": "se√±alando arriba",
        "üé¢": "monta√±a rusa",
        "üôå": "celebrando",
        "üåò": "luna menguante",
        "ü´°": "saludo",
        "üôãüèª‚Äç‚ôÄ": "mujer levantando mano",
        "ü§¶‚Äç‚ôÇ": "error",
        "üåä": "ola",
        "üòâ": "gui√±o",
        "ü•∂": "fr√≠o",
        "üíã": "beso",
        "üá∫üá¶": "Ucrania",
        "üò∂‚Äçüå´": "confundido",
        "üå¨": "viento",
        "üí©": "mierda",
        "üëåüèº": "perfecto",
        "üôÜ‚Äç‚ôÇ": "hombre OK",
        "üí™üèΩ": "fuerza",
        "üò±": "gritando",
        "1Ô∏è‚É£": "uno",
        "ü§ò": "rock",
        "üëâ": "se√±alando derecha",
        "üôÇ": "sonriendo",
        "üëÅ": "ojo",
        "üëÄ": "ojos",
        "üî•": "fuego",
        "‚è∫": "grabar",
        "üòÖ": "sudando",
        "‚ùó": "exclamaci√≥n",
        "üòï": "confuso",
        "ü•í": "pepino",
        "üéÇ": "torta",
        "üò•": "aliviado",
        "‚úåüèΩ": "victoria",
        "üéæ": "tenis",
        "üíö": "coraz√≥n verde",
        "üíî": "coraz√≥n roto",
        "üëç": "bien",
        "üê∂": "perro",
        "‚úî": "verificado",
        "‚úåüèª": "paz",
        "üí™": "m√∫sculo",
        "üéà": "globo",
        "ü§ë": "dinero en la cara",
        "üòæ": "gato enfadado",
        "üíµ": "billete",
        "üëãüèª": "saludando",
        "üëàüèª": "se√±alando izquierda",
        "üí∞": "bolsa de dinero",
        "üéº": "m√∫sica",
        "üêÆ": "vaca",
        "üá¶üá∑": "Argentina",
        "ü§∑üèº‚Äç‚ôÄ": "mujer encogi√©ndose",
        "üíÉ": "bailando",
        "ü§Æ": "vomitando",
        "üá∑üá∫": "Rusia",
        "üòé": "genial",
        "ü•≥": "fiesta",
        "‚ö∞": "ata√∫d",
        "üíØ": "cien puntos",
        "üìà": "gr√°fico subiendo",
        "üò≠": "llorando",
        "üò™": "somnoliento",
        "ü§ûüèº": "suerte",
        "ü§¶üèΩ‚Äç‚ôÇ": "hombre avergonzado",
        "‚ñ∂": "reproducir",
        "‚õî": "prohibido",
        "üé∂": "notas musicales",
        "üôä": "mono callado",
        "üåö": "luna nueva",
        "üëè": "aplaudiendo",
        "üôèüèΩ": "rezando",
        "üòÑ": "feliz",
        "ü§¶üèª‚Äç‚ôÇ": "error hombre",
        "üá®üá≥": "China",
        "üëåüèª": "OK",
        "ü§ôüèª": "ll√°mame",
        "üá≥üá¨": "Nigeria",
        "üòÉ": "alegre",
        "‚ÑπÔ∏è": "informaci√≥n",
        "üó£": "hablando",
        "üôåüèº": "manos levantadas",
        "ü§û": "cruzando dedos",
        "üòú": "broma",
        "üéµ": "nota musical",
        "ü§ü": "te amo",
        "‚úà": "avi√≥n",
        "üëåüèΩ": "perfecto",
        "ü§¶üèΩ": "verg√ºenza",
        "üëçüèæ": "bien",
        "üîπ": "diamante azul",
        "üòù": "lengua fuera",
        "üí∂": "euro",
        "ü§ì": "nerd",
        "üò∂": "sin expresi√≥n",
        "üêÅ": "rat√≥n",
        "üêó": "jabal√≠",
        "ü§¶üèª‚Äç‚ôÄ": "mujer avergonzada",
        "üçè": "manzana verde",
        "üü¢": "c√≠rculo verde",
        "üôåüèΩ": "celebraci√≥n",
        "üá™üá∏": "Espa√±a",
        "‚ú®": "brillo",
        "ü§∑üèª‚Äç‚ôÇ": "hombre encogi√©ndose",
        "üö®": "alarma",
        "ü•∞": "amor",
        "‚ò∫": "sonrisa",
        "ü§∑‚Äç‚ôÇ": "duda",
        "ü§Ø": "cabeza explotando",
        "ü•∫": "suplicando",
        "üêü": "pez",
        "üáÆüá≥": "India",
        "üòê": "neutral",
        "üòÅ": "sonriendo amplio",
        "üôãüèª‚Äç‚ôÇ": "levantando mano",
        "üòì": "sudor",
        "üï∫": "bailando",
        "üòØ": "sorprendido",
        "üëâüèª": "se√±alando derecha",
        "üí•": "explosi√≥n",
        "üò¢": "llorando",
        "ü¶ñ": "T-Rex",
        "‚ö°": "rayo",
        "üò¥": "durmiendo",
        "ü´£": "espiando",
        "üòª": "gato enamorado",
        "ü•µ": "caliente",
        "üëçüèª": "pulgar arriba",
        "üáßüáæ": "Bielorrusia",
        "ü§∑üèΩ‚Äç‚ôÄ": "mujer dudando",
        "üòã": "saboreando",
        "üö´": "prohibido",
        "üëÖ": "lengua",
        "üòÜ": "riendo mucho",
        "üòä": "sonriendo feliz",
        "üòá": "√°ngel",
        "üò†": "enojado",
        "üåé": "Am√©ricas",
        "‚¨á": "bajar",
        "üòû": "triste",
        "üîµ": "c√≠rculo azul",
        "üì®": "correo",
        "üëÜ": "arriba",
        "üòò": "besando",
        "üåñ": "luna gibosa",
        "‚ù§": "coraz√≥n rojo",
        "‚òù": "dedo arriba",
        "‚úå": "victoria",
        "üçª": "brindis",
        "ü§ù": "apret√≥n de manos",
        "üëã": "saludo",
        "üí≤": "d√≥lar",
        "üëçüèº": "bien",
        "üö∂üèª‚Äç‚ôÇ": "hombre caminando",
        "ü§î": "pensando",
        "üòπ": "gato riendo",
        "ü´µ": "se√±alando",
        "ü§≠": "riendo callado",
        "ü™Ç": "paraca√≠das",
        "üòà": "diablo",
        "üî∞": "principiante",
        "ü´Ä": "coraz√≥n",
        "üòí": "molesto",
        "ü§∑": "no s√©",
        "üòÄ": "felicidad",
        "üçÄ": "tr√©bol",
        "üî™": "cuchillo",
        "üòÆ": "boca abierta",
        "üí¨": "hablar",
        "‚úã": "mano levantada",
        "üòå": "alivio",
        "üí¶": "sudor",
        "ü§∑üèº‚Äç‚ôÇ": "duda",
        "‚òπ": "tristeza",
        "ü§®": "sospecha",
        "ü§ôüèΩ": "ll√°mame",
        "üîª": "tri√°ngulo abajo",
        "üõç": "compras",
        "ü§ß": "estornudo",
        "üí´": "mareo",
        "üëº": "√°ngel",
        "ü§å": "pellizco",
        "üí®": "r√°pido",
        "üòõ": "lengua fuera",
        "üéÑ": "√°rbol de Navidad",
        "ü•π": "l√°grimas contenidas",
        "‚òÄ": "sol",
        "üåï": "luna llena",
        "üá∫üá∏": "Estados Unidos",
        "üëèüèº": "aplausos",
        "‚Äº": "doble exclamaci√≥n",
        "üöÄ": "cohete",
        "üò°": "furioso",
        "üò¨": "nervios",
        "üî¥": "c√≠rculo rojo",
        "üôèüèª": "orando",
        "üôà": "mono tap√°ndose",
        "ü¶•": "perezoso",
        "üåô": "luna creciente",
        "üëà": "se√±alando izquierda",
        "üê∑": "cerdo",
        "ü•∏": "disfrazado",
        "üòè": "sonrisa p√≠cara",
        "üòö": "beso cerrado",
        "‚öì": "ancla",
        "üëå": "OK",
        "ü§üüèª": "te amo",
        "üåå": "v√≠a l√°ctea",
        "‚ö†": "advertencia",
        "ü•±": "bostezando",
        "üê¨": "delf√≠n",
        "üìä": "gr√°fico",
        "üêÄ": "rata",
        "ü§ó": "abrazo",
        "üòî": "pensativo",
        "üëèüèª": "aplaudiendo",
        "üáßüá¨": "Bulgaria",
        "ü•¥": "mareado"
    }
    if emoji is None:
        return emoji_map  # Tr·∫£ v·ªÅ to√†n b·ªô t·ª´ ƒëi·ªÉn n·∫øu kh√¥ng truy·ªÅn emoji
    return emoji_map.get(emoji, emoji)  # Tr·∫£ v·ªÅ √°nh x·∫° ho·∫∑c emoji g·ªëc

def replace_emojis_in_text(text):
    result = text
    for emoji, spanish_text in map_emoji_to_spanish().items():
        result = result.replace(emoji, f" {spanish_text} ")
    words = result.split()
    if not words:
        return ""
    cleaned_words = [words[0]]
    for i in range(1, len(words)):
        if words[i] != words[i-1]:
            cleaned_words.append(words[i])
    return " ".join(cleaned_words).strip()

for filename in os.listdir(data_dir):
    if filename.endswith(".json"):
        with open(os.path.join(data_dir, filename), "r") as f:
            messages = json.load(f)
            nick = filename.split(".")[0]
            subjects[nick] = [
                replace_emojis_in_text(str(msg["message"]) if msg["message"] is not None else "")
                for msg in messages
            ]

In [6]:
subjects['user10343']

['Voy cargando la escopeta pa longuear alguna monedilla',
 'Todav√≠a no',
 'Esto est√° m√°s aburrido',
 'Falta mucho para que entren los chinos ?',
 'Por se√±ales de Facu ?',
 'Compart√≠ ahora bien',
 'Y compartilos ahora . Igual te van a banear cuando lean esto riendo',
 'Sacate la papa de la boca',
 'Se pic√≥ la clande',
 'riendo',
 'Screenshot ( 6 mar . 2022 20:03 : 38 )',
 'La paciencia sudando',
 'Divino ! aplaudiendo',
 'Screenshot ( 6 mar . 2022 20:22 : 39 )',
 'Cerrada lunitaa enamorado',
 'mono tap√°ndose',
 'Un short en 38k es lo mismo que ir al casino riendo',
 'Recuperar que ?',
 'Si no entra una ballena generosa en los pr√≥ximos minutos btc se pega un palo',
 'Vamoooo ma√±ana se come enamorado',
 'De proyectos no se absolutamente nada sonriendo feliz solo tradeo cualquier cosa que se mueva riendo',
 'riendo aplaudiendo',
 '4 k en un trade ? Mierda ! Que haces ac√° hermano',
 'Anoche no estaban meta long todos ?',
 'Yo vi que subian longs a dos manos .. pero no recuerdo si 

In [None]:
task1_labels = {}
with open("data/task1/train/gold_task1.txt", "r") as f:
    next(f)
    for line in f:
        nick, risk = line.strip().split(",")
        task1_labels[nick] = int(risk)

messages = []
labels = []
for nick, subject_messages in subjects.items():
    if nick in task1_labels:
        if task1_labels[nick] == 0:
            msg_len = len(subject_messages)
            if msg_len > 1:
                split_point = msg_len // 2
                messages.append(subject_messages[:split_point])
                messages.append(subject_messages[split_point:])
                labels.extend([0, 0])
            else:
                messages.append(subject_messages)
                labels.append(0)
        else:
            messages.append(subject_messages)
            labels.append(task1_labels[nick])

T·ªïng s·ªë m·∫´u: 528, Ph√¢n b·ªë nh√£n: [356 172]


In [None]:
class EmbDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]


In [None]:
class EmbDatasetRNNAug(Dataset):
    def __init__(self, embeddings, labels, thr_rng=0.7, n_msg=10):
        self.embeddings = embeddings
        self.labels = labels
        self.emb0 = [embeddings[i] for i in range(len(labels)) if labels[i] == 0]
        self.thr_rng = thr_rng
        self.n_msg = n_msg

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.labels[idx] == 0:
            nr_msg = np.random.randint(1, len(self.embeddings[idx]) + 1)
            return self.embeddings[idx][:nr_msg], self.labels[idx]
        else:
            rnd = np.random.uniform()
            if rnd > self.thr_rng:
                neutral = self.emb0[np.random.randint(0, len(self.emb0))]
                n_extra = np.random.randint(1, min(len(neutral), self.n_msg))
                return np.concatenate([neutral[:n_extra], self.embeddings[idx]], axis=0), self.labels[idx]
            return self.embeddings[idx], self.labels[idx]

In [None]:
def make_plot(train_scores, val_scores, y_label, figsize=(8,5)):
    fig, ax = plt.subplots(1,1,figsize=figsize)
    ax.plot(train_scores, label='Train')
    ax.plot(val_scores, label='Val')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    return fig, ax

In [None]:
def get_cls_embeddings(all_messages, model, tokenizer, device, m_length=96):
    model.to(device)
    model.eval()
    embeddings = []
    with torch.no_grad():
        for subject_messages in tqdm(all_messages):
            input = tokenizer(subject_messages, padding=True, truncation=True, max_length=m_length, return_tensors='pt')
            output = model(**input.to(device))
            embeddings.append(output.last_hidden_state[:, 0, :].cpu().numpy())
    return embeddings

In [None]:
def lstm_collate(batch):
    labels = [x[1] for x in batch]
    labels = torch.tensor(labels, dtype=torch.long)
    data = [torch.tensor(x[0], dtype=torch.float32) for x in batch]
    batch_data = pad_sequence(data)
    lens = torch.tensor([len(x) for x in data], dtype=torch.long).unsqueeze(0).unsqueeze(-1)
    lens -= 1
    return batch_data, lens, labels


In [None]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, h_size, output_dim, dropout=0):
        super().__init__()
        self.lstm = nn.LSTM(input_size, h_size, num_layers=1, batch_first=False, dropout=dropout, bidirectional=True)
        self.attention = nn.Linear(2 * h_size, 1)
        self.classifier = nn.Linear(2 * h_size, output_dim)

    def forward(self, seq_data, seq_lens, state=None):
        lstm_out, _ = self.lstm(seq_data)
        attention_weights = torch.softmax(self.attention(lstm_out), dim=0)
        context_vector = torch.sum(attention_weights * lstm_out, dim=0)
        return self.classifier(context_vector)

    def predict_all_timesteps(self, seq_data, seq_lens, state=None):
        lstm_out, _ = self.lstm(seq_data)
        attention_weights = torch.softmax(self.attention(lstm_out), dim=0)
        logits_all = self.classifier(lstm_out)
        pred_all = torch.argmax(logits_all, dim=2)
        ts_predictions = [pred_all[:seq_lens[0, i].item(), i].squeeze().cpu().numpy() for i in range(pred_all.shape[1])]
        return ts_predictions

In [13]:
# Thi·∫øt l·∫≠p v√† hu·∫•n luy·ªán
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
name = 'pysentimiento/robertuito-sentiment-analysis'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)
embs = get_cls_embeddings(messages, model, tokenizer, device, m_length=96)

Some weights of RobertaModel were not initialized from the model checkpoint at pysentimiento/robertuito-sentiment-analysis and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 528/528 [00:31<00:00, 16.93it/s]


In [None]:
def validate_tms_rnn(subject_embs, labels, net, device):
    net.to(device)
    net.eval()
    predictions = []
    with torch.no_grad():
        batch_data, batch_lens, _ = lstm_collate([(embs, label) for embs, label in zip(subject_embs, labels)])
        batch_data, batch_lens = batch_data.to(device), batch_lens.to(device)
        seq_predictions = net.predict_all_timesteps(batch_data, batch_lens)
        for seq_pred, true_label in zip(seq_predictions, labels):
            idxs = np.nonzero(seq_pred)[0]
            predictions.append(seq_pred[idxs[0]] if len(idxs) > 0 else 0)
    return {
        'acc': metrics.accuracy_score(labels, predictions),
        'macro_f1': metrics.f1_score(labels, predictions, average='macro', zero_division=0),
        'cls_report': metrics.classification_report(labels, predictions, zero_division=0),
        'cfm': metrics.confusion_matrix(labels, predictions)
    }
def train_gdro_rnn_sl(net, optimizer, device, criterion, train_dl, q, soft_labels, eta=0.1):
    net.to(device)
    net.train()
    loss = 0
    num_batches = 0
    preds = []
    labels = []
    for batch_data, batch_lens, batch_labels in train_dl:
        labels.append(batch_labels.numpy())
        unique_batch_labels = np.unique(batch_labels.numpy())
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
        batch_lens = batch_lens.to(device)
        optimizer.zero_grad()
        out = net(batch_data, batch_lens)
        batch_losses = nn.functional.cross_entropy(out, soft_labels[batch_labels], reduction='none')
        for cls in unique_batch_labels:
            idx_cls = batch_labels == cls
            q[cls] *= (eta * batch_losses[idx_cls].mean()).exp().item()
        q /= q.sum()
        loss_value = sum(q[cls] * batch_losses[batch_labels == cls].mean() for cls in unique_batch_labels)
        loss_value.backward()
        optimizer.step()
        loss += loss_value.item()
        num_batches += 1
        preds.append(torch.argmax(out, axis=-1).cpu().numpy())
    labels = np.concatenate(labels, axis=0)
    preds = np.concatenate(preds, axis=0)
    return {
        'loss': loss / num_batches,
        'acc': metrics.accuracy_score(labels, preds),
        'macro_f1': metrics.f1_score(labels, preds, average='macro', zero_division=0),
        'cls_report': metrics.classification_report(labels, preds, zero_division=0),
        'cfm': metrics.confusion_matrix(labels, preds)
    }, q

def run_train_gdro_rnn_sl(net, optimizer, criterion, device, train_dl, train_embs, train_labels,
                          val_embs, val_labels, soft_labels, output_dir, max_epochs=150, n_classes=2, eta=0.1):
    best_macro_f1_val = 0
    logs = {'train': defaultdict(list), 'val': defaultdict(list), 'train_eval': defaultdict(list), 'epoch': 0}
    q = torch.ones(n_classes, dtype=torch.float32, device=device) / n_classes
    with tqdm(total=max_epochs, desc="Training", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]") as pbar:
        for epoch in range(max_epochs):
            train_report, q = train_gdro_rnn_sl(net, optimizer, device, criterion, train_dl, q, soft_labels, eta)
            logs['train'].update({k: logs['train'][k] + [v] for k, v in train_report.items()})
            val_report = validate_tms_rnn(val_embs, val_labels, net, device)
            logs['val'].update({k: logs['val'][k] + [v] for k, v in val_report.items()})
            pbar.set_postfix({'loss': f"{train_report['loss']:.4f}", 'val_f1': f"{val_report['macro_f1']:.4f}"})
            pbar.update(1)
            if val_report['macro_f1'] >= best_macro_f1_val:
                best_macro_f1_val = val_report['macro_f1']
                torch.save(net.cpu().state_dict(), f'{output_dir}/net_params.pt')
                logs['epoch'] = epoch
                train_report_eval = validate_tms_rnn(train_embs, train_labels, net, device)
                logs['train_eval'].update({k: logs['train_eval'][k] + [v] for k, v in train_report_eval.items()})
    return logs

def format_confusion_matrix(cfm):
    return "\n".join([
        f"{'':>8} {'Pred 0':>8} {'Pred 1':>8}",
        f"{'True 0':>8} {cfm[0,0]:8d} {cfm[0,1]:8d}",
        f"{'True 1':>8} {cfm[1,0]:8d} {cfm[1,1]:8d}"
    ])

ds = EmbDatasetRNNAug(embs, labels, thr_rng=0.6, n_msg=10)
train_ds, test_ds = random_split(ds, [0.8, 0.2], generator=torch.Generator().manual_seed(2909))
train_labels = [labels[i] for i in train_ds.indices]
test_labels = [labels[i] for i in test_ds.indices]
val_embs = [embs[i] for i in test_ds.indices]
val_labels = [labels[i] for i in test_ds.indices]
train_embs = [embs[i] for i in train_ds.indices]
train_labels = [labels[i] for i in train_ds.indices]

save_dir = 'pre_trained_models'
os.makedirs(save_dir, exist_ok=True)
soft_labels = torch.tensor([[0.95, 0.05], [0.05, 0.95]], dtype=torch.float32, device=device)

f1s = []
for h_size in [96, 128]:
    scores = []
    for bs in [2, 4, 8]:
        dir_name = f'processed_data_h_{h_size}_bs_{bs}_0.95_0.05'
        output_dir = os.path.join(save_dir, dir_name)
        os.makedirs(output_dir, exist_ok=True)

        random.seed(2909)
        np.random.seed(2909)
        torch.manual_seed(2909)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, drop_last=False, collate_fn=lstm_collate)
        net = LSTMClassifier(embs[0].shape[-1], h_size=h_size, output_dim=2)
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
        loss_fn = nn.CrossEntropyLoss()

        logs = run_train_gdro_rnn_sl(net, optimizer, loss_fn, device, train_dl, train_embs, train_labels,
                                     val_embs, val_labels, soft_labels, output_dir, max_epochs=150)

        for k in logs['train_eval']:
            if k not in ['cls_report', 'cfm']:
                fig, ax = make_plot(logs['train'][k], logs['val'][k], k)
                fig.savefig(f'{output_dir}/{k}.png')
                plt.close(fig)
                fig, ax = make_plot(logs['train_eval'][k], logs['val'][k], k)
                fig.savefig(f'{output_dir}/{k}_eval.png')
                plt.close(fig)

        np.save(f'{output_dir}/logs.npy', logs, allow_pickle=True)

        arg = np.argmax(logs['val']['macro_f1'])
        table_data = [
            ["Epoch t·ªët nh·∫•t", arg],
            ["Val Macro F1", f"{logs['val']['macro_f1'][arg]:.4f}"],
            ["Train_eval Macro F1", f"{logs['train_eval']['macro_f1'][-1]:.4f}"]
        ]
        print(f"{Fore.CYAN}=== K·∫øt qu·∫£ cho h_size: {h_size}, batch_size: {bs} ==={Style.RESET_ALL}")
        print(tabulate(table_data, headers=["Metric", "Value"], tablefmt="pretty", colalign=("left", "right")))
        print(f"\n{Fore.GREEN}Validation Set:{Style.RESET_ALL}")
        print(format_confusion_matrix(logs['val']['cfm'][arg]))
        print(f"\n{logs['val']['cls_report'][arg]}")
        print(f"\n{Fore.YELLOW}Train_eval Set:{Style.RESET_ALL}")
        print(format_confusion_matrix(logs['val']['cfm'][-1]))
        print(f"\n{logs['train_eval']['cls_report'][-1]}")
        print(f"{Fore.CYAN}{'='*50}{Style.RESET_ALL}\n")
        scores.append(logs['val']['macro_f1'][arg])
    f1s.append(scores)

Ph√¢n b·ªë nh√£n trong t·∫≠p train:
[282 278]
Ph√¢n b·ªë nh√£n trong t·∫≠p test:
[74 66]


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [03:06<00:00,  1.24s/it, loss=0.1489, val_f1=0.9569]


[36m=== K·∫øt qu·∫£ cho h_size: 96, batch_size: 2 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch t·ªët nh·∫•t      |    117 |
| Val Macro F1        | 0.9713 |
| Train_eval Macro F1 | 0.9982 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       72        2
  True 1        2       64

              precision    recall  f1-score   support

           0       0.97      0.97      0.97        74
           1       0.97      0.97      0.97        66

    accuracy                           0.97       140
   macro avg       0.97      0.97      0.97       140
weighted avg       0.97      0.97      0.97       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       72        2
  True 1        4       62

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       282
           1       1.00      1.00      1.00       278


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [01:53<00:00,  1.32it/s, loss=0.1862, val_f1=0.9353]


[36m=== K·∫øt qu·∫£ cho h_size: 96, batch_size: 4 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch t·ªët nh·∫•t      |    142 |
| Val Macro F1        | 0.9498 |
| Train_eval Macro F1 | 1.0000 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       71        3
  True 1        4       62

              precision    recall  f1-score   support

           0       0.95      0.96      0.95        74
           1       0.95      0.94      0.95        66

    accuracy                           0.95       140
   macro avg       0.95      0.95      0.95       140
weighted avg       0.95      0.95      0.95       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       71        3
  True 1        6       60

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       282
           1       1.00      1.00      1.00       278


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [01:05<00:00,  2.30it/s, loss=0.1989, val_f1=0.8049]


[36m=== K·∫øt qu·∫£ cho h_size: 96, batch_size: 8 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch t·ªët nh·∫•t      |     49 |
| Val Macro F1        | 0.9000 |
| Train_eval Macro F1 | 0.9821 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       62       12
  True 1        2       64

              precision    recall  f1-score   support

           0       0.97      0.84      0.90        74
           1       0.84      0.97      0.90        66

    accuracy                           0.90       140
   macro avg       0.91      0.90      0.90       140
weighted avg       0.91      0.90      0.90       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       49       25
  True 1        2       64

              precision    recall  f1-score   support

           0       1.00      0.96      0.98       282
           1       0.97      1.00      0.98       278


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [03:45<00:00,  1.51s/it, loss=0.1474, val_f1=0.9071]


[36m=== K·∫øt qu·∫£ cho h_size: 128, batch_size: 2 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch t·ªët nh·∫•t      |     68 |
| Val Macro F1        | 0.9569 |
| Train_eval Macro F1 | 1.0000 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       72        2
  True 1        4       62

              precision    recall  f1-score   support

           0       0.95      0.97      0.96        74
           1       0.97      0.94      0.95        66

    accuracy                           0.96       140
   macro avg       0.96      0.96      0.96       140
weighted avg       0.96      0.96      0.96       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       65        9
  True 1        4       62

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       282
           1       1.00      1.00      1.00       278

Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [02:29<00:00,  1.00it/s, loss=0.1878, val_f1=0.9284]


[36m=== K·∫øt qu·∫£ cho h_size: 128, batch_size: 4 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch t·ªët nh·∫•t      |    130 |
| Val Macro F1        | 0.9569 |
| Train_eval Macro F1 | 1.0000 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       72        2
  True 1        4       62

              precision    recall  f1-score   support

           0       0.95      0.97      0.96        74
           1       0.97      0.94      0.95        66

    accuracy                           0.96       140
   macro avg       0.96      0.96      0.96       140
weighted avg       0.96      0.96      0.96       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       68        6
  True 1        4       62

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       282
           1       1.00      1.00      1.00       278

Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 150/150 [01:51<00:00,  1.35it/s, loss=0.1958, val_f1=0.8712]

[36m=== K·∫øt qu·∫£ cho h_size: 128, batch_size: 8 ===[0m
+---------------------+--------+
| Metric              |  Value |
+---------------------+--------+
| Epoch t·ªët nh·∫•t      |     30 |
| Val Macro F1        | 0.9071 |
| Train_eval Macro F1 | 0.9929 |
+---------------------+--------+

[32mValidation Set:[0m
           Pred 0   Pred 1
  True 0       65        9
  True 1        4       62

              precision    recall  f1-score   support

           0       0.94      0.88      0.91        74
           1       0.87      0.94      0.91        66

    accuracy                           0.91       140
   macro avg       0.91      0.91      0.91       140
weighted avg       0.91      0.91      0.91       140


[33mTrain_eval Set:[0m
           Pred 0   Pred 1
  True 0       58       16
  True 1        2       64

              precision    recall  f1-score   support

           0       0.99      0.99      0.99       282
           1       0.99      0.99      0.99       278


