**0. Import Lib**

In [169]:
import re
import spacy
import pandas as pd
import random
import numpy as np
import unicodedata
from dataset import UITVSFCDataset

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import os

In [170]:
def set_seed(seed):
    random.seed(seed)                 
    np.random.seed(seed)              
    torch.manual_seed(seed)           
    torch.cuda.manual_seed(seed)      
    torch.cuda.manual_seed_all(seed)  

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(1410)

**1. Import Data**

In [171]:
DATASET_PATH = 'data/UIT-VSFC'
VI_SYNONYM_PATH = 'data/synonym_vi.csv'
VI_STOPWORD_PATH = 'data/vietnamese-stopwords.txt'

In [172]:
dataset = UITVSFCDataset(root_dir=DATASET_PATH).load_data()
synonym_file = pd.read_csv(VI_SYNONYM_PATH)

with open(VI_STOPWORD_PATH,'r',encoding='utf-8') as f:
    lines = f.readlines()
    
vi_stopwords = [line.rstrip('\n') for line in lines]

synonym_dict = {}
for index, row in synonym_file.iterrows():
    word = row['word'].strip()
    synonyms = [s.strip() for s in row['synonyms'].split(';')]
    synonym_dict[word] = synonyms
 
sentiment_labels = {
    0: 'negative',
    1: 'neutral',
    2: 'positive'
}

topic_labels = {
    0: 'lecturer',
    1: 'training program',
    2: 'facility',
    3: 'other'
}

emoji_labels = {
    ':))': 'colonsmilesmile',
    ':)':  'colonsmile',
	':(':  'colonsad',
	'@@':  'colonsurprise',
	'<3':  'colonlove',
	':d':  'colonsmilesmile',
	':3':  'coloncontemn',
	':v':  'colonbigsmile',
	':_':  'coloncc',
	':p':  'colonsmallsmile',
	'>>':  'coloncolon',
	':">': 'colonlovelove',
	'^^': 'colonhihi',
	':': 'doubledot',
	':(': 'colonsadcolon',
	':’(': 'colonsadcolon',
	':@': 'colondoublesurprise',
	'v.v': 'vdotv',
	'...': 'dotdotdot',
	'/': 'fraction',
	'c#': 'cshrap'
}

In [173]:
synonym_dict

{'đẹp': ['xinh', 'tuyệt đẹp', 'dễ thương', 'đẹp đẽ'],
 'tốt': ['ổn', 'xuất sắc', 'tuyệt vời', 'chất lượng'],
 'xấu': ['tệ', 'kém', 'dở', 'kém chất lượng'],
 'nhanh': ['mau', 'lẹ', 'tốc độ', 'nhanh chóng'],
 'chậm': ['từ tốn', 'ì ạch', 'không nhanh'],
 'vui': ['hạnh phúc', 'phấn khởi', 'vui vẻ'],
 'buồn': ['chán', 'rầu rĩ', 'buồn bã'],
 'giận': ['tức', 'cáu', 'nổi nóng'],
 'mệt': ['đuối', 'kiệt sức', 'mệt mỏi'],
 'khó': ['phức tạp', 'gian nan', 'không dễ'],
 'dễ': ['đơn giản', 'dễ dàng', 'thuận tiện'],
 'to': ['lớn', 'bự', 'khổng lồ'],
 'nhỏ': ['bé', 'tí hon', 'nhỏ nhắn'],
 'sai': ['không đúng', 'nhầm', 'lệch'],
 'đúng': ['chính xác', 'chuẩn', 'hợp lý'],
 'mua': ['sắm', 'đặt mua', 'lấy'],
 'bán': ['trao đổi', 'cung cấp', 'kinh doanh'],
 'đắt': ['cao giá', 'mắc', 'giá cao'],
 'rẻ': ['giá thấp', 'bình dân', 'phải chăng'],
 'ngon': ['tuyệt ngon', 'thơm ngon', 'đậm vị'],
 'dở': ['tệ', 'kém ngon', 'khó ăn'],
 'được': ['ổn', 'hợp lý', 'được việc'],
 'không': ['chẳng', 'không hề', 'không có'],

In [174]:
vi_stopwords[:5]

['a lô', 'a ha', 'ai', 'ai ai', 'ai nấy']

**2. Preprocess Data**

In [175]:
class EDA():
    def __init__(self, synonym_dict):
        self.synonym_dict = synonym_dict
    
    def process(self,tokens:list,eda_probs = 0.3) -> list:
        augmented_tokens = tokens
        
        if random.random() < eda_probs:
            augmented_tokens = self.synonym_augmentation(tokens,prob=0.4)
    
        if random.random() < eda_probs:
            augmented_tokens = self.random_insertion(augmented_tokens, n=1)        
    
        if random.random() < eda_probs:
            augmented_tokens = self.random_swap(augmented_tokens, n=1)
    
        if random.random() < eda_probs:
            augmented_tokens = self.random_deletion(augmented_tokens, word_threshold=3, p=0.1)
    
        return augmented_tokens
    
    def synonym_augmentation(self,tokens,prob=0.3):
        new_tokens = []
        for token in tokens:
            if token in self.synonym_dict and random.random() < prob:
                synonyms = self.synonym_dict[token]
                if synonyms:
                    new_token = random.choice(synonyms)
                    new_tokens.append(new_token)
                else:
                    new_tokens.append(token)
            else:
                new_tokens.append(token)
        return new_tokens
    
    def random_insertion(self,tokens, n=2):
        new_tokens = tokens.copy()
    
        for _ in range(n):
            token = random.choice(new_tokens).lower()
        
            if token not in self.synonym_dict:
                continue
        
            synonym = random.choice(synonym_dict[token])
            insert_pos = random.randint(0, len(new_tokens))
            new_tokens.insert(insert_pos, synonym)
        
        return new_tokens
    
    def random_swap(self,tokens, n= 1):
        len_tokens = len(tokens)
    
        if len_tokens < 2:
            return tokens
    
        new_tokens = tokens.copy()
    
        for _ in range(n):
            i = random.randint(0, len_tokens - 1)
            j = random.randint(0, len_tokens - 1)
        
            while i == j:
                j = random.randint(0, len_tokens - 1)
        
            new_tokens[i], new_tokens[j] = new_tokens[j], new_tokens[i]
        
        return new_tokens
    
    def random_deletion(self,tokens, word_threshold = 3, p=0.2):
        if len(tokens) < word_threshold:
            return tokens
    
        new_tokens = []
        for token in tokens:
            if random.random() > p:
                new_tokens.append(token)
    
        if len(new_tokens) == 0:
            return [random.choice(tokens)]
    
        return new_tokens
    
    

In [176]:
class DataPreprocess():
    def __init__(self,stopwords,tokenizer, emoji_labels, transform = None):
        self.stopwords = stopwords
        self.tokenizer = tokenizer
        
        self.emoji_labels = emoji_labels
        self.__emoji_sorted = sorted(self.emoji_labels.keys(), key=len, reverse=True)
        self.__emoji_regex = "|".join(re.escape(k) for k in self.__emoji_sorted)
        self.__pattern = re.compile(self.__emoji_regex, flags=re.IGNORECASE)

        self.vietnamese_pattern = r"[^a-zA-Z0-9\sàáảãạâầấẩẫậăằắẳẵặèéẻẽẹêềếểễệìíỉĩịòóỏõọôồốổỗộơờớởỡợùúủũụưừứửữựỳýỷỹỵđĐ]"

        self.transform = transform
        
    def process(self, text:str, apply_transform =True) -> list:
        
        text = text.strip()
        text = re.sub(r'\s+', ' ', text)
        text = text.lower()
        
        text = self.replace_emoji(text)

        placeholders = re.findall(r"<[^<>]+>", text)
        for i, ph in enumerate(placeholders):
            text = text.replace(ph, f"__PLACEHOLDER{i}__")
        
        text = re.sub(self.vietnamese_pattern, "", text, flags=re.UNICODE)
        
        for i, ph in enumerate(placeholders):
            text = text.replace(f"__PLACEHOLDER{i}__", ph)
                
        tokens = self.tokenize(text)
        
        if self.transform is not None and apply_transform == True:
            tokens = self.transform.process(tokens)
        
        tokens = self.remove_stopwords(tokens)
        tokens = self.remove_accents(tokens)
       
        #tokens = [t for t in tokens if not re.match(r'^\s+$', t)]
        return tokens
        
    def __repl(self,match:str):
        key = match.group().lower()
        return " " + self.emoji_labels[key] + " "
    
    def replace_emoji(self,text:str):
        return self.__pattern.sub(self.__repl,text)
        
    def remove_accents(self,tokens:list) -> list:
        new_tokens = []
        
        for token in tokens:
            nfkd = unicodedata.normalize('NFD', token)
            token = ''.join([c for c in nfkd if not unicodedata.category(c) == 'Mn'])
            new_tokens.append(token)    

        return new_tokens
    
    def remove_stopwords(self,tokens:list) -> list:
        new_tokens = []
        
        for token in tokens:
            if token in self.stopwords:
                continue
            
            new_tokens.append(token)
            
        return new_tokens
    
    def tokenize(self, text:str) -> list:
        tokens = [token.text for token in self.tokenizer(text)]
        
        cleaned_token = []
        for token in tokens:
            if re.match(r'^\s*$', token):
                continue
            
            if token in self.emoji_labels.values():
                cleaned_token.append(f"<{token}>")
                continue
        
            if re.search(r'\d', token):
                if re.fullmatch(r'\d+', token):
                    cleaned_token.append("<NUM>")
                else:
                    cleaned_token.append("<UNK>")
            else:
                cleaned_token.append(token)
                        
        return cleaned_token

In [177]:
tokenizer = spacy.blank('vi')

eda = EDA(synonym_dict=synonym_dict)
preprocess = DataPreprocess(stopwords=vi_stopwords,
                            tokenizer=tokenizer,
                            emoji_labels=emoji_labels,
                            transform=eda)

In [178]:
example = dataset['test'][170]['sentence'] + ":)) @@ 143 42342 colonsmilesmile"
print("example: " + example)
print("preprocess: ")
print(preprocess.process(example))

example: do trình độ tiếng anh của lớp không cao chỉ một số ít có khả năng nghe , đọc , và hiểu được bài giảng của thấy nên hiệu quả của việc giảng dạy bằng tiếng anh là chưa cao .:)) @@ 143 42342 colonsmilesmile
preprocess: 
['trinh đo', 'tieng', 'lop', 'mot so it', 'kha nang', 'đoc', 'on', 'giang', 'hieu qua', 'giang day', 'tieng', '<colonsmilesmile>', '<colonsurprise>', '<NUM>', '<NUM>', '<colonsmilesmile>']


**3. Build Vocab**

In [179]:
class MyVocab():
    def __init__(self,dataset = None,synonym_dict:dict = None,emoji_labels:dict = None,preprocess = None):
        self.preprocess = preprocess
        
        self.dataset = dataset
        self.synonym_dict = synonym_dict
        self.emoji_labels = emoji_labels
        
        self.__vocab = set()
        
        self.word2idx = {}
        self.idx2word = []
        
        self.vocab_size = None
        
    def get_vocab_size(self):
        return self.vocab_size
    
    def build_vocab(self):
        if self.preprocess is None or self.dataset is None or self.synonym_dict is None or self.emoji_labels is None:
            return None
        for key, syn_list in self.synonym_dict.items(): 
            key_norm = self.preprocess.remove_accents([key])[0]
            self.__vocab.add(key_norm)
        
        for v in syn_list:
            v_norm = self.preprocess.remove_accents([v])[0]
            self.__vocab.add(v_norm)

        for value in self.emoji_labels.values():
            self.__vocab.add(value)
            
        for _, row in self.dataset['train'].items():
            tokens = self.preprocess.process(row['sentence'],False)
            self.__vocab.update(tokens)
                
        special_tokens = ["<SOS>","<EOS>","<PAD>", "<UNK>","<NUM>"]
        
        final_vocab = special_tokens + sorted(self.__vocab)
       
        self.idx2word = final_vocab
        
        self.word2idx = {w: i for i, w in enumerate(final_vocab)}

        self.vocab_size = len(final_vocab)
        
        print("Vocab size:", self.vocab_size)
        return self.word2idx
    
    def save_vocab(self, path="vocab.txt"):
        with open(path, "w", encoding="utf-8") as f:
            for word in self.idx2word:
                f.write(word + "\n")
        
    def load_vocab(self, path="vocab.txt"):
        with open(path, "r", encoding="utf-8") as f:
            vocab = [line.strip() for line in f.readlines()]

        self.idx2word = vocab
        self.word2idx = {w: i for i, w in enumerate(vocab)}

        self.vocab_size = len(vocab)
        print("Loaded vocab size:", self.vocab_size)

    def mapping(self, tokens):
        return [self.word2idx.get(t, self.word2idx["<UNK>"]) for t in tokens]
    
    def reverse_mapping(self, indices: list) -> list:
        return [self.idx2word[i] if i < len(self.idx2word) else "<UNK>" for i in indices]

In [180]:
vocab = MyVocab(
    preprocess=preprocess,
    emoji_labels=emoji_labels,
    dataset=dataset,
    synonym_dict=synonym_dict
)

word2idx = vocab.build_vocab()
vocab.save_vocab("vocab.txt")

Vocab size: 2598


In [181]:
test_vocab = MyVocab()
test_vocab.load_vocab()

Loaded vocab size: 2598


In [182]:
example = dataset['dev'][32]['sentence']
print("example: " + example)
print("process: ")
test_process = preprocess.process(example,False)
print(test_process)
print('mapping: ' )
test_mapping = test_vocab.mapping(test_process)
print(test_mapping)
print('reverse_mapping: ')
print(test_vocab.reverse_mapping(test_mapping))

example: thầy dạy hay và có tâm , mặc dù môn học lịch sử khá khô khan nhưng chúng em hiểu được rất nhiều bài học qua bài giảng của thầy .
process: 
['thay', 'day', 'tam', 'mac du', 'mon hoc', 'lich su', 'kho khan', 'bai hoc', 'giang', 'thay']
mapping: 
[1875, 490, 1792, 1142, 1221, 1062, 930, 62, 649, 1875]
reverse_mapping: 
['thay', 'day', 'tam', 'mac du', 'mon hoc', 'lich su', 'kho khan', 'bai hoc', 'giang', 'thay']


**4. Build Dataset and DataLoader**

In [183]:
class VSFCDataset(Dataset):
    def __init__(self, data, preprocess, vocab):
        super().__init__()
        self.data = data
        self.preprocess = preprocess
        self.vocab = vocab
        self.sos_idx = vocab.word2idx["<SOS>"]
        self.eos_idx = vocab.word2idx["<EOS>"]
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        sample = self.data[index]
        tokens = self.preprocess.process(sample['sentence'])
        token_ids = self.vocab.mapping(tokens)
        token_ids = [self.sos_idx] + token_ids + [self.eos_idx]
        sentiment_label = int(sample['sentiment'])
        topic_label = int(sample['topic'])
        
        return {
            'input_ids': torch.tensor(token_ids, dtype=torch.long),
            'sentiment': torch.tensor(sentiment_label, dtype=torch.long),
            'topic': torch.tensor(topic_label, dtype=torch.long)
        }


def collate_fn(batch, pad_idx):
    
    input_ids = [b['input_ids'] for b in batch]
    sentiment = torch.stack([b['sentiment'] for b in batch])
    topic = torch.stack([b['topic'] for b in batch])
    
    max_len = max(seq.size(0) for seq in input_ids)
    padded_inputs = []
    for seq in input_ids:
        pad_len = max_len - seq.size(0)
        padded_seq = torch.cat([seq, torch.full((pad_len,), pad_idx, dtype=torch.long)])
        padded_inputs.append(padded_seq)
    
    input_ids = torch.stack(padded_inputs) 
    
    return {'input_ids': input_ids, 'sentiment': sentiment, 'topic': topic}

    

In [184]:
pad_idx = int(vocab.word2idx["<PAD>"])

train_dataset = VSFCDataset(dataset['train'], preprocess, vocab)

train_loader = DataLoader(
    train_dataset,
    batch_size=200,
    shuffle=True,
    collate_fn=lambda x: collate_fn(x, pad_idx),
    drop_last=True
)

test_dataset = VSFCDataset(dataset['dev'], preprocess, vocab)
val_dataset = VSFCDataset(dataset['test'], preprocess, vocab)

test_loader = DataLoader(
    test_dataset,
    batch_size=200,
    shuffle=False,          
    collate_fn=lambda x: collate_fn(x, pad_idx)
)

val_loader = DataLoader(
    val_dataset,
    batch_size=200,
    shuffle=True,
    collate_fn=lambda x: collate_fn(x, pad_idx)
)


**5. Build Model**

In [185]:
class VSFCClassifier(nn.Module):
    def __init__(self, vocab_size,  sentiment_classes , topic_classes, pad_idx,embed_dim = 128, hidden_dim =256):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx
            )
        
        self.encoder = nn.LSTM(
            input_size= embed_dim,
            hidden_size=hidden_dim,
            batch_first=True, 
            bidirectional=True
            )
        
        self.attn = nn.Linear(
            in_features=hidden_dim*2,
            out_features=1
            )
        
        self.sentiment_head = nn.Linear(
            in_features=hidden_dim*2,
            out_features=sentiment_classes
        )
        
        self.topic_head = nn.Linear(
            in_features = hidden_dim*2,
            out_features = topic_classes
        )
        
    def forward(self,input_ids):
        x = self.embedding(input_ids)
        enc_out, _ = self.encoder(x)
        
        attn_weights = torch.softmax(self.attn(enc_out), dim= 1)
        pooled = torch.sum(enc_out*attn_weights,dim = 1 )
        
        sentiment_logits = self.sentiment_head(pooled)
        topic_logits = self.topic_head(pooled)
        
        return sentiment_logits, topic_logits

In [186]:
model = VSFCClassifier(
    vocab_size=vocab.get_vocab_size(),
    sentiment_classes=len(sentiment_labels),
    topic_classes=len(topic_labels),
    pad_idx=vocab.word2idx["<PAD>"],
    )

**6. Train**

In [187]:
print("CUDA available:", torch.cuda.is_available())

if torch.cuda.is_available():
    print("GPU name:", torch.cuda.get_device_name(0))
    print("Current device index:", torch.cuda.current_device())
    print("Memory allocated:", torch.cuda.memory_allocated(0))
    print("Memory cached:", torch.cuda.memory_reserved(0))

CUDA available: True
GPU name: NVIDIA GeForce RTX 3060
Current device index: 0
Memory allocated: 21549568
Memory cached: 1472200704


In [188]:
def save_model(model, path="saved_model/vsfc_model.pth"):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to: {path}")
    
def load_model(model_class, vocab_size, embed_dim, hidden_dim, 
               sentiment_classes, topic_classes, path, device,pad_idx):

    model = model_class(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        sentiment_classes=sentiment_classes,
        topic_classes=topic_classes,
        pad_idx = pad_idx
    )

    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    
    return model

In [189]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, train_loader, val_loader, device, num_epochs=5, lr=1e-3,save_plot_dir="plots"):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    model.to(device)
    
    os.makedirs(save_plot_dir, exist_ok=True)
    
    train_losses = []
    val_sent_accs = []
    val_topic_accs = []
    
    for epoch in range(num_epochs):

        model.train()
        total_loss = 0
        
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            sentiment_labels = batch['sentiment'].to(device)
            topic_labels = batch['topic'].to(device)
            
            optimizer.zero_grad()
            s_logits, t_logits = model(input_ids)
            loss = criterion(s_logits, sentiment_labels) + criterion(t_logits, topic_labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        train_losses.append(avg_loss)
        
        model.eval()
        correct_sent, total_sent = 0, 0
        correct_topic, total_topic = 0, 0
        
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                sentiment_labels = batch['sentiment'].to(device)
                topic_labels = batch['topic'].to(device)
                
                s_logits, t_logits = model(input_ids)
                s_pred = s_logits.argmax(dim=1)
                t_pred = t_logits.argmax(dim=1)
                
                correct_sent += (s_pred == sentiment_labels).sum().item()
                total_sent += sentiment_labels.size(0)
                
                correct_topic += (t_pred == topic_labels).sum().item()
                total_topic += topic_labels.size(0)
        
        sent_acc = correct_sent / total_sent
        topic_acc = correct_topic / total_topic
        val_sent_accs.append(sent_acc)
        val_topic_accs.append(topic_acc)
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss={avg_loss:.4f}, "
              f"Val Sentiment Acc={sent_acc:.4f}, Val Topic Acc={topic_acc:.4f}")
    
    save_model(model=model)

    
    epochs = range(1, num_epochs+1)
    
    plt.figure(figsize=(8,5))
    plt.plot(epochs, train_losses, marker='o', label='Train Loss')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Loss")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(save_plot_dir, "train_loss.png"))
    plt.close()
    
    plt.figure(figsize=(8,5))
    plt.plot(epochs, val_sent_accs, marker='o', label='Val Sentiment Acc')
    plt.plot(epochs, val_topic_accs, marker='s', label='Val Topic Acc')
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Validation Accuracy")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(save_plot_dir, "val_accuracy.png"))
    plt.close()
  

In [190]:
train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    num_epochs=100)

Epoch 1/100, Loss=1.4385, Val Sentiment Acc=0.7243, Val Topic Acc=0.7934
Epoch 2/100, Loss=1.0176, Val Sentiment Acc=0.7612, Val Topic Acc=0.8225
Epoch 3/100, Loss=0.9197, Val Sentiment Acc=0.7720, Val Topic Acc=0.8313
Epoch 4/100, Loss=0.8661, Val Sentiment Acc=0.7713, Val Topic Acc=0.8354
Epoch 5/100, Loss=0.8113, Val Sentiment Acc=0.7751, Val Topic Acc=0.8380
Epoch 6/100, Loss=0.7740, Val Sentiment Acc=0.7780, Val Topic Acc=0.8440
Epoch 7/100, Loss=0.7315, Val Sentiment Acc=0.7846, Val Topic Acc=0.8364
Epoch 8/100, Loss=0.6877, Val Sentiment Acc=0.7836, Val Topic Acc=0.8399
Epoch 9/100, Loss=0.6437, Val Sentiment Acc=0.7827, Val Topic Acc=0.8395
Epoch 10/100, Loss=0.6155, Val Sentiment Acc=0.7900, Val Topic Acc=0.8326
Epoch 11/100, Loss=0.5733, Val Sentiment Acc=0.7858, Val Topic Acc=0.8354
Epoch 12/100, Loss=0.5582, Val Sentiment Acc=0.7858, Val Topic Acc=0.8339
Epoch 13/100, Loss=0.5155, Val Sentiment Acc=0.7846, Val Topic Acc=0.8332
Epoch 14/100, Loss=0.4857, Val Sentiment Acc=0.

**7. Test**

In [192]:
saved_model = load_model(
    model_class = VSFCClassifier,
    vocab_size=vocab.get_vocab_size(),
    embed_dim=128,
    hidden_dim=256,
    sentiment_classes=len(sentiment_labels),
    topic_classes=len(topic_labels),
    path='saved_model/vsfc_model.pth',
    device=device,
    pad_idx=vocab.word2idx["<PAD>"]
    )

def test(model, test_loader, device, save_plot_dir="plots"):
    model.eval()
    
    correct_sent, total_sent = 0, 0
    correct_topic, total_topic = 0, 0

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            sentiment_labels = batch["sentiment"].to(device)
            topic_labels = batch["topic"].to(device)

            s_logits, t_logits = model(input_ids)

            s_pred = s_logits.argmax(dim=1)
            t_pred = t_logits.argmax(dim=1)

            correct_sent += (s_pred == sentiment_labels).sum().item()
            total_sent += sentiment_labels.size(0)

            correct_topic += (t_pred == topic_labels).sum().item()
            total_topic += topic_labels.size(0)

    sent_acc = correct_sent / total_sent
    topic_acc = correct_topic / total_topic

    print(f"\n========== TEST RESULT ==========")
    print(f"Sentiment Accuracy: {sent_acc:.4f}")
    print(f"Topic Accuracy:     {topic_acc:.4f}")
    print("=================================\n")

    os.makedirs(save_plot_dir, exist_ok=True)

    categories = ["Sentiment", "Topic"]
    accuracies = [sent_acc, topic_acc]

    plt.figure(figsize=(6, 4))
    plt.bar(categories, accuracies)
    plt.ylim(0, 1.0)
    plt.title("Test Accuracy")
    plt.ylabel("Accuracy")

    filename = os.path.join(save_plot_dir, "test_accuracy.png")
    plt.savefig(filename)
    plt.close()

    return sent_acc, topic_acc

  model.load_state_dict(torch.load(path, map_location=device))


In [193]:
test(
    model = saved_model,
    test_loader=test_loader, 
    device=device)


Sentiment Accuracy: 0.7896
Topic Accuracy:     0.8345



(0.7896399241945673, 0.8344914718888187)