In [16]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, precision_recall_fscore_support, accuracy_score

In [29]:
model_name = "vinai/phobert-base"   # or "xlm-roberta-base"
max_len = 128
batch_size = 8
lr = 2e-5
num_epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
freeze_encoder = True     # True cho dataset nhỏ; False nếu muốn fine-tune encoder
attn_num_heads = 4
attn_dropout = 0.2
save_path = "best_bert_attn.pt"

print("Device:", device)

Device: cuda


In [18]:
class QRCdataset(Dataset):
    def __init__(self, df, tokenizer, max_len=128, include_response=True):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.texts = []
        self.labels = df["label"].tolist()
        for _, r in df.iterrows():
            context = str(r.get("context",""))
            prompt  = str(r.get("prompt",""))
            response = str(r.get("response","")) if include_response else ""
            # build single string with simple separators
            txt = f"[CTX] {context} [Q] {prompt} [R] {response}"
            self.texts.append(txt)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )
        item = {k: enc[k].squeeze(0) for k in enc}
        item["label"] = self.labels[idx]
        return item

def collate_fn(batch):
    input_ids = torch.stack([b["input_ids"] for b in batch])
    attention_mask = torch.stack([b["attention_mask"] for b in batch])
    labels = [b["label"] for b in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

In [19]:
class SelfAttentionExtractor(nn.Module):
    """
    Multi-head self-attention applied to encoder token embeddings.
    Returns:
      - attn_output: (B, L, H) (residual of attn)
      - attn_weights: (B, num_heads, L, L)  (raw attention weights)
    We'll pool weights over heads and take weights for CLS-like pooling.
    """
    def __init__(self, embed_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.layernorm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: (B, L, H)
        # attention_mask: (B, L) with 1 for token, 0 for pad
        # Build key_padding_mask for MultiheadAttention: True for positions to **ignore**
        key_padding_mask = (attention_mask == 0)  # (B, L) bool
        attn_out, attn_weights = self.mha(hidden_states, hidden_states, hidden_states,
                                          key_padding_mask=key_padding_mask)  # attn_weights (B, L, L) or (B, num_heads, L, L) depending version
        out = self.layernorm(hidden_states + self.dropout(attn_out))
        # ensure attn_weights shape: (B, num_heads, L, L) if returned (B, L, L) convert to (B,1,L,L)
        if attn_weights.dim() == 3:
            attn_weights = attn_weights.unsqueeze(1)  # (B,1,L,L)
        return out, attn_weights

In [20]:
class SmallTransformerEncoder(nn.Module):
    def __init__(self, d_model, nhead=4, dim_feedforward=2048, nlayers=1, dropout=0.1):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=nlayers)
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, x, attention_mask):
        # x: (B,L,H). TransformerEncoder uses src_key_padding_mask where True = pad
        src_key_padding_mask = (attention_mask == 0)
        out = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        out = self.layernorm(out)
        return out

In [21]:
class BertAttnLLMClassifier(nn.Module):
    def __init__(self, encoder_name, num_classes, freeze_encoder=True,
                 attn_heads=4, attn_dropout=0.1, small_enc_layers=1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        embed_dim = self.encoder.config.hidden_size
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False

        self.attn_extractor = SelfAttentionExtractor(embed_dim, num_heads=attn_heads, dropout=attn_dropout)
        # a small transformer acting as LLM-like encoder on top of attn output
        self.small_encoder = SmallTransformerEncoder(d_model=embed_dim, nhead=min(attn_heads,8), nlayers=small_enc_layers)
        # classification head: pool then MLP
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(embed_dim // 2, num_classes)
        )

    def forward(self, input_ids, attention_mask, return_attn=False):
        # Get Transformer encoder outputs (last_hidden_state)
        enc_out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        last_hidden = enc_out.last_hidden_state     # (B, L, H)

        # Self-attention extractor (gives refined hidden + weights)
        attn_out, attn_weights = self.attn_extractor(last_hidden, attention_mask)  # attn_out: (B,L,H)

        # Small transformer encoder (LLM-like)
        small_out = self.small_encoder(attn_out, attention_mask)  # (B,L,H)

        # Pooling: mean pooled over valid tokens weighted by mask
        mask = attention_mask.unsqueeze(-1)  # (B,L,1)
        sum_hidden = (small_out * mask).sum(dim=1)               # (B,H)
        denom = mask.sum(dim=1).clamp(min=1e-9)
        pooled = sum_hidden / denom                              # (B,H)

        logits = self.classifier(pooled)                         # (B,num_classes)

        if return_attn:
            # produce per-token importance by averaging heads and taking attention from [all->token] or token->all
            # attn_weights shape: (B, num_heads, L, L)
            attn_avg = attn_weights.mean(dim=1)  # (B, L, L)
            # For token importance we can take attention paid *to* each token: mean over source positions
            token_imp = attn_avg.mean(dim=1)   # (B, L)
            return logits, token_imp, attn_weights
        return logits

In [22]:
def prepare_datasets(train_csv, val_csv, test_csv, tokenizer, include_response=True):
    df_train = pd.read_csv(train_csv).fillna("")
    df_val   = pd.read_csv(val_csv).fillna("")
    df_test  = pd.read_csv(test_csv).fillna("")

    # Label encoding consistently across sets
    le = LabelEncoder()
    le.fit(list(df_train["label"].unique()) + list(df_val["label"].unique()) + list(df_test["label"].unique()))

    # store string labels back to csv datasets (model input uses strings)
    df_train["label"] = le.transform(df_train["label"])
    df_val["label"]   = le.transform(df_val["label"])
    df_test["label"]  = le.transform(df_test["label"])

    train_ds = QRCdataset(df_train, tokenizer, max_len=max_len, include_response=include_response)
    val_ds = QRCdataset(df_val, tokenizer, max_len=max_len, include_response=include_response)
    test_ds = QRCdataset(df_test, tokenizer, max_len=max_len, include_response=include_response)

    return train_ds, val_ds, test_ds, le

In [24]:
def evaluate_model(model, loader, device):
    model.eval()
    preds, trues = [], []
    losses = []
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            labels = torch.tensor(batch["labels"], dtype=torch.long).to(device) if isinstance(batch["labels"], list) else batch["labels"].to(device)
            logits = model(input_ids=input_ids, attention_mask=mask)
            loss = criterion(logits, labels)
            losses.append(loss.item())
            preds.extend(torch.argmax(logits, dim=-1).cpu().numpy().tolist())
            trues.extend(labels.cpu().numpy().tolist())
    avg_loss = np.mean(losses) if losses else 0.0
    return avg_loss, preds, trues

def train_loop(train_loader, val_loader, model, optimizer, scheduler, device, le):
    best_f1 = 0.0
    criterion = nn.CrossEntropyLoss()
    for epoch in range(1, num_epochs+1):
        model.train()
        total_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
            input_ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            labels = torch.tensor(batch["labels"], dtype=torch.long).to(device) if isinstance(batch["labels"], list) else batch["labels"].to(device)
            optimizer.zero_grad()
            logits = model(input_ids=input_ids, attention_mask=mask)
            loss = criterion(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            total_loss += loss.item()

        train_loss = total_loss / len(train_loader)
        val_loss, val_preds, val_trues = evaluate_model(model, val_loader, device)
        prf = precision_recall_fscore_support(val_trues, val_preds, average="macro", zero_division=0)
        acc = accuracy_score(val_trues, val_preds)
        print(f"Epoch {epoch}: TrainLoss={train_loss:.4f} ValLoss={val_loss:.4f} Acc={acc:.4f} MacroF1={prf[2]:.4f}")

        if prf[2] > best_f1:
            best_f1 = prf[2]
            torch.save(model.state_dict(), save_path)
            print("Saved best model, best macro-F1:", best_f1)

In [25]:
def predict_with_attention(model, tokenizer, text, device, max_len=128):
    model.eval()
    enc = tokenizer(text, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    mask = enc["attention_mask"].to(device)
    with torch.no_grad():
        logits, token_imp, attn_weights = model(input_ids=input_ids, attention_mask=mask, return_attn=True)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
        token_imp = token_imp.cpu().numpy()[0]  # (L,)
        attn_weights = attn_weights.cpu().numpy()[0]  # (num_heads, L, L)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy())
    return probs, tokens, token_imp, attn_weights

In [30]:
if __name__ == "__main__":
    # paths
    train_csv = "train.csv"
    val_csv = "val.csv"
    test_csv = "test.csv"

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    train_ds, val_ds, test_ds, le = prepare_datasets(train_csv, val_csv, test_csv, tokenizer, include_response=True)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    num_classes = len(le.classes_)
    model = BertAttnLLMClassifier(model_name, num_classes=num_classes, freeze_encoder=freeze_encoder,
                                  attn_heads=attn_num_heads, attn_dropout=attn_dropout, small_enc_layers=1).to(device)

    # optimizer + scheduler
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=1e-2)
    total_steps = num_epochs * len(train_loader)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps)

    # train
    train_loop(train_loader, val_loader, model, optimizer, scheduler, device, le)

    # load best and evaluate on test
    model.load_state_dict(torch.load(save_path, map_location=device))
    test_loss, test_preds, test_trues = evaluate_model(model, test_loader, device)
    print("=== Test Report ===")
    print(classification_report(test_trues, test_preds, target_names=le.classes_, zero_division=0))

    # example inference with attention
    sample_text = pd.read_csv(test_csv).iloc[0]
    text = f"[CTX] {sample_text['context']} [Q] {sample_text['prompt']} [R] {sample_text['response']}"
    probs, tokens, token_imp, attn_weights = predict_with_attention(model, tokenizer, text, device, max_len=max_len)
    print("Pred probs:", probs)
    # show top-k tokens by importance
    topk = np.argsort(token_imp)[-10:][::-1]
    for idx in topk:
        print(idx, tokens[idx], float(token_imp[idx]))

Epoch 1/50: 100%|██████████| 18/18 [00:01<00:00, 13.97it/s]


Epoch 1: TrainLoss=1.1022 ValLoss=1.1052 Acc=0.3750 MacroF1=0.4042
Saved best model, best macro-F1: 0.4041514041514041


Epoch 2/50: 100%|██████████| 18/18 [00:01<00:00, 15.31it/s]


Epoch 2: TrainLoss=1.0784 ValLoss=1.1083 Acc=0.5000 MacroF1=0.4202
Saved best model, best macro-F1: 0.42020202020202024


Epoch 3/50: 100%|██████████| 18/18 [00:01<00:00, 14.95it/s]


Epoch 3: TrainLoss=1.0566 ValLoss=1.0948 Acc=0.4375 MacroF1=0.4481
Saved best model, best macro-F1: 0.4481481481481482


Epoch 4/50: 100%|██████████| 18/18 [00:01<00:00, 15.03it/s]


Epoch 4: TrainLoss=1.0382 ValLoss=1.0926 Acc=0.3750 MacroF1=0.3889


Epoch 5/50: 100%|██████████| 18/18 [00:01<00:00, 15.22it/s]


Epoch 5: TrainLoss=1.0157 ValLoss=1.0934 Acc=0.4375 MacroF1=0.3796


Epoch 6/50: 100%|██████████| 18/18 [00:01<00:00, 14.98it/s]


Epoch 6: TrainLoss=1.0051 ValLoss=1.0805 Acc=0.3750 MacroF1=0.3397


Epoch 7/50: 100%|██████████| 18/18 [00:01<00:00, 15.38it/s]


Epoch 7: TrainLoss=0.9812 ValLoss=1.0774 Acc=0.3125 MacroF1=0.3296


Epoch 8/50: 100%|██████████| 18/18 [00:01<00:00, 13.69it/s]


Epoch 8: TrainLoss=0.9574 ValLoss=1.0905 Acc=0.3750 MacroF1=0.2804


Epoch 9/50: 100%|██████████| 18/18 [00:01<00:00, 15.31it/s]


Epoch 9: TrainLoss=0.9288 ValLoss=1.0787 Acc=0.5625 MacroF1=0.5305
Saved best model, best macro-F1: 0.5304843304843305


Epoch 10/50: 100%|██████████| 18/18 [00:01<00:00, 14.97it/s]


Epoch 10: TrainLoss=0.9054 ValLoss=1.0676 Acc=0.4375 MacroF1=0.3776


Epoch 11/50: 100%|██████████| 18/18 [00:01<00:00, 15.26it/s]


Epoch 11: TrainLoss=0.8828 ValLoss=1.0631 Acc=0.4375 MacroF1=0.3776


Epoch 12/50: 100%|██████████| 18/18 [00:01<00:00, 14.87it/s]


Epoch 12: TrainLoss=0.8581 ValLoss=1.0593 Acc=0.3750 MacroF1=0.3397


Epoch 13/50: 100%|██████████| 18/18 [00:01<00:00, 15.33it/s]


Epoch 13: TrainLoss=0.8223 ValLoss=1.0484 Acc=0.3750 MacroF1=0.3397


Epoch 14/50: 100%|██████████| 18/18 [00:01<00:00, 15.29it/s]


Epoch 14: TrainLoss=0.7867 ValLoss=1.0463 Acc=0.3750 MacroF1=0.3463


Epoch 15/50: 100%|██████████| 18/18 [00:01<00:00, 15.02it/s]


Epoch 15: TrainLoss=0.7706 ValLoss=1.0426 Acc=0.4375 MacroF1=0.4185


Epoch 16/50: 100%|██████████| 18/18 [00:01<00:00, 15.29it/s]


Epoch 16: TrainLoss=0.7369 ValLoss=1.0437 Acc=0.3750 MacroF1=0.3463


Epoch 17/50: 100%|██████████| 18/18 [00:01<00:00, 15.32it/s]


Epoch 17: TrainLoss=0.7042 ValLoss=1.0356 Acc=0.4375 MacroF1=0.4185


Epoch 18/50: 100%|██████████| 18/18 [00:01<00:00, 15.29it/s]


Epoch 18: TrainLoss=0.6753 ValLoss=1.0398 Acc=0.4375 MacroF1=0.4238


Epoch 19/50: 100%|██████████| 18/18 [00:01<00:00, 15.08it/s]


Epoch 19: TrainLoss=0.6384 ValLoss=1.0325 Acc=0.5000 MacroF1=0.4940


Epoch 20/50: 100%|██████████| 18/18 [00:01<00:00, 15.29it/s]


Epoch 20: TrainLoss=0.6152 ValLoss=1.0480 Acc=0.4375 MacroF1=0.4185


Epoch 21/50: 100%|██████████| 18/18 [00:01<00:00, 14.93it/s]


Epoch 21: TrainLoss=0.5697 ValLoss=1.0541 Acc=0.5000 MacroF1=0.4868


Epoch 22/50: 100%|██████████| 18/18 [00:01<00:00, 15.01it/s]


Epoch 22: TrainLoss=0.5543 ValLoss=1.0710 Acc=0.3750 MacroF1=0.3397


Epoch 23/50: 100%|██████████| 18/18 [00:01<00:00, 14.94it/s]


Epoch 23: TrainLoss=0.5075 ValLoss=1.0598 Acc=0.5000 MacroF1=0.4930


Epoch 24/50: 100%|██████████| 18/18 [00:01<00:00, 14.98it/s]


Epoch 24: TrainLoss=0.4956 ValLoss=1.0824 Acc=0.5000 MacroF1=0.4868


Epoch 25/50: 100%|██████████| 18/18 [00:01<00:00, 14.88it/s]


Epoch 25: TrainLoss=0.4492 ValLoss=1.0922 Acc=0.5000 MacroF1=0.4868


Epoch 26/50: 100%|██████████| 18/18 [00:01<00:00, 15.08it/s]


Epoch 26: TrainLoss=0.4406 ValLoss=1.1365 Acc=0.3750 MacroF1=0.3397


Epoch 27/50: 100%|██████████| 18/18 [00:01<00:00, 15.09it/s]


Epoch 27: TrainLoss=0.4006 ValLoss=1.1124 Acc=0.5000 MacroF1=0.4868


Epoch 28/50: 100%|██████████| 18/18 [00:01<00:00, 15.21it/s]


Epoch 28: TrainLoss=0.3732 ValLoss=1.0963 Acc=0.5000 MacroF1=0.4930


Epoch 29/50: 100%|██████████| 18/18 [00:01<00:00, 14.92it/s]


Epoch 29: TrainLoss=0.3667 ValLoss=1.1511 Acc=0.4375 MacroF1=0.4185


Epoch 30/50: 100%|██████████| 18/18 [00:01<00:00, 15.18it/s]


Epoch 30: TrainLoss=0.3219 ValLoss=1.1052 Acc=0.5625 MacroF1=0.5607
Saved best model, best macro-F1: 0.5606837606837607


Epoch 31/50: 100%|██████████| 18/18 [00:01<00:00, 14.69it/s]


Epoch 31: TrainLoss=0.2983 ValLoss=1.1568 Acc=0.4375 MacroF1=0.4185


Epoch 32/50: 100%|██████████| 18/18 [00:01<00:00, 14.99it/s]


Epoch 32: TrainLoss=0.2560 ValLoss=1.1504 Acc=0.5000 MacroF1=0.4930


Epoch 33/50: 100%|██████████| 18/18 [00:01<00:00, 15.08it/s]


Epoch 33: TrainLoss=0.2719 ValLoss=1.1730 Acc=0.5625 MacroF1=0.5623
Saved best model, best macro-F1: 0.5622710622710623


Epoch 34/50: 100%|██████████| 18/18 [00:01<00:00, 14.90it/s]


Epoch 34: TrainLoss=0.2373 ValLoss=1.2324 Acc=0.5000 MacroF1=0.5016


Epoch 35/50: 100%|██████████| 18/18 [00:01<00:00, 14.49it/s]


Epoch 35: TrainLoss=0.2328 ValLoss=1.1709 Acc=0.5625 MacroF1=0.5607


Epoch 36/50: 100%|██████████| 18/18 [00:01<00:00, 14.64it/s]


Epoch 36: TrainLoss=0.1897 ValLoss=1.3031 Acc=0.5000 MacroF1=0.5016


Epoch 37/50: 100%|██████████| 18/18 [00:01<00:00, 14.83it/s]


Epoch 37: TrainLoss=0.1981 ValLoss=1.2159 Acc=0.5000 MacroF1=0.4905


Epoch 38/50: 100%|██████████| 18/18 [00:01<00:00, 14.74it/s]


Epoch 38: TrainLoss=0.1559 ValLoss=1.2845 Acc=0.5000 MacroF1=0.5016


Epoch 39/50: 100%|██████████| 18/18 [00:01<00:00, 14.82it/s]


Epoch 39: TrainLoss=0.1624 ValLoss=1.3196 Acc=0.5000 MacroF1=0.5016


Epoch 40/50: 100%|██████████| 18/18 [00:01<00:00, 14.95it/s]


Epoch 40: TrainLoss=0.1623 ValLoss=1.3512 Acc=0.5000 MacroF1=0.5016


Epoch 41/50: 100%|██████████| 18/18 [00:01<00:00, 14.93it/s]


Epoch 41: TrainLoss=0.1366 ValLoss=1.3012 Acc=0.5000 MacroF1=0.5016


Epoch 42/50: 100%|██████████| 18/18 [00:01<00:00, 14.78it/s]


Epoch 42: TrainLoss=0.1234 ValLoss=1.3299 Acc=0.5000 MacroF1=0.5016


Epoch 43/50: 100%|██████████| 18/18 [00:01<00:00, 14.99it/s]


Epoch 43: TrainLoss=0.1344 ValLoss=1.3595 Acc=0.5000 MacroF1=0.5016


Epoch 44/50: 100%|██████████| 18/18 [00:01<00:00, 14.77it/s]


Epoch 44: TrainLoss=0.1330 ValLoss=1.4305 Acc=0.4375 MacroF1=0.4312


Epoch 45/50: 100%|██████████| 18/18 [00:01<00:00, 14.78it/s]


Epoch 45: TrainLoss=0.1167 ValLoss=1.3088 Acc=0.5625 MacroF1=0.5623


Epoch 46/50: 100%|██████████| 18/18 [00:01<00:00, 14.73it/s]


Epoch 46: TrainLoss=0.1153 ValLoss=1.3247 Acc=0.5625 MacroF1=0.5623


Epoch 47/50: 100%|██████████| 18/18 [00:01<00:00, 14.92it/s]


Epoch 47: TrainLoss=0.1045 ValLoss=1.3863 Acc=0.5000 MacroF1=0.5016


Epoch 48/50: 100%|██████████| 18/18 [00:01<00:00, 14.78it/s]


Epoch 48: TrainLoss=0.0950 ValLoss=1.3827 Acc=0.5000 MacroF1=0.5016


Epoch 49/50: 100%|██████████| 18/18 [00:01<00:00, 14.89it/s]


Epoch 49: TrainLoss=0.0847 ValLoss=1.3862 Acc=0.5000 MacroF1=0.5022


Epoch 50/50: 100%|██████████| 18/18 [00:01<00:00, 14.78it/s]


Epoch 50: TrainLoss=0.0988 ValLoss=1.4549 Acc=0.5000 MacroF1=0.5016
=== Test Report ===
              precision    recall  f1-score   support

   extrinsic       0.50      0.69      0.58        13
   intrinsic       0.62      0.62      0.62        13
          no       0.44      0.29      0.35        14

    accuracy                           0.53        40
   macro avg       0.52      0.53      0.51        40
weighted avg       0.52      0.53      0.51        40

Pred probs: [0.82819843 0.15018632 0.02161524]
8 chuyển 0.010460099205374718
9 tiếp 0.009309806860983372
16 đồng 0.008942298591136932
75 mét 0.008940604515373707
111 mé@@ 0.008901253342628479
32 và 0.008878671564161777
17 bằng 0.008838163688778877
93 - 0.008818496949970722
30 xuống 0.008793191984295845
118 mé@@ 0.008790582418441772


In [12]:
# ---------- Evaluate on test set + print per-class report ----------
model.load_state_dict(torch.load("best_model_attn.pt"))
test_loss, test_preds, test_trues = evaluate_model(model, test_loader, device)
print("Test Loss:", test_loss)
print(classification_report(test_trues, test_preds, target_names=le.classes_))


Test Loss: 1.0029139161109923
              precision    recall  f1-score   support

   extrinsic       0.50      0.31      0.38        13
   intrinsic       0.50      0.69      0.58        13
          no       0.50      0.50      0.50        14

    accuracy                           0.50        40
   macro avg       0.50      0.50      0.49        40
weighted avg       0.50      0.50      0.49        40



In [13]:
# ---------- Inference example: trả về attention scores để giải thích ----------
def predict_with_attention(model, tokenizer, text, device, max_len=128):
    model.eval()
    enc = tokenizer(text, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    mask = enc["attention_mask"].to(device)
    with torch.no_grad():
        logits, attn_weights = model(input_ids=input_ids, attention_mask=mask, return_attn=True)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
        weights = attn_weights.cpu().numpy()[0]  # (seq_len,)
    # map tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy())
    return probs, tokens, weights

# ví dụ:
text = df_test.iloc[0]["context"] + " " + df_test.iloc[0]["prompt"] + " " + df_test.iloc[0]["response"]
probs, tokens, weights = predict_with_attention(model, tokenizer, text, device)
# show top tokens with high attention
topk = np.argsort(weights)[-10:][::-1]
for idx in topk:
    print(idx, tokens[idx], weights[idx])
print("Pred probs:", probs)


12 đồng 0.008845827
5 tiếp 0.0087670665
13 bằng 0.008642147
18 địa 0.008587842
110 nơi 0.008553644
122 , 0.0084927995
4 chuyển 0.008469682
51 bình 0.008415734
19 hình 0.008396566
118 vực 0.00839231
Pred probs: [0.33478364 0.46328855 0.2019278 ]
