In [1]:
from torch.nn.utils.rnn import pad_sequence
import re
import string
import json
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
from torchmetrics.classification import MulticlassF1Score
import sklearn.metrics
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [2]:
import math
from math import *
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import os
from glob import glob
from tqdm.auto import tqdm
from torch.optim import AdamW
import matplotlib.pyplot as plt
from collections import defaultdict

In [4]:
from transformers import set_seed
import random

SEED = 42
# Seed every RNG the notebook touches, in a fixed order: Python, NumPy,
# PyTorch (CPU, current GPU, all GPUs) and finally transformers' helper.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all, set_seed):
    _seed_fn(SEED)

## Preprocess Data

In [5]:
def tok_to_idx(toks):
    """Map each token to its position in ``toks``.

    If a token occurs more than once, the value is its LAST position
    (later occurrences overwrite earlier ones).
    """
    lookup = {}
    for position, token in enumerate(toks):
        lookup[token] = position
    return lookup

def idx_to_tok(toks):
    """Map every position in ``toks`` to the token stored there.

    Built directly from the positions, so every index resolves to a token even
    when ``toks`` contains duplicates. Inverting ``tok_to_idx`` would lose the
    earlier index of a repeated token.
    """
    return dict(enumerate(toks))

In [6]:
# Words (runs of \w between word boundaries) or single non-space punctuation characters.
# Compiled once at module level instead of on every loop iteration.
_TOKEN_RE = re.compile(r"\b\w+\b|[^\w\s]")


def tokenization(data, without_comma):
    """Split each record's ``'FOL2NS'`` text into tokens, prefixed with ``"<bos>"``.

    Args:
        data: iterable of dict-like records, each with a ``'FOL2NS'`` string.
        without_comma: if True, drop every ``","`` token.

    Returns:
        One token list per record, in input order, each starting with ``"<bos>"``.
    """
    output = []
    for dic in data:
        tok = _TOKEN_RE.findall(dic['FOL2NS'])
        if without_comma:
            tok = [t for t in tok if t != ","]
        output.append(["<bos>"] + tok)
    return output

In [7]:
def detokenization(idx_seq, idx2tok):
    """Translate ids back to token strings and join them with single spaces.

    Raises KeyError if an id is missing from ``idx2tok``.
    """
    return " ".join(map(idx2tok.__getitem__, idx_seq))

In [8]:
def mapping(data, tok2idx, without_comma):
    """Convert every record's tokenized ``'FOL2NS'`` text into a list of vocabulary ids.

    Args:
        data: records accepted by ``tokenization``.
        tok2idx: token -> id lookup.
        without_comma: forwarded to ``tokenization``.

    Returns:
        One id list per record, in input order.

    Raises:
        KeyError: if a token is not in ``tok2idx``. The message names the token
            and the record index, so vocabulary gaps are easy to locate.
    """
    output = []
    for rec_idx, tokens in enumerate(tokenization(data, without_comma)):
        ids = []
        for tok in tokens:
            try:
                ids.append(tok2idx[tok])
            except KeyError:
                raise KeyError(
                    f"Token {tok!r} in record {rec_idx} is not in the vocabulary"
                ) from None
        output.append(ids)
    return output

In [9]:
def reconstruct_data(data, tok2idx, without_comma):
    """Attach an ``"input_ids"`` list to every record (in place) and return ``data``.

    The ids come from ``mapping``, which tokenizes each record's ``'FOL2NS'`` text.
    """
    for record, token_ids in zip(data, mapping(data, tok2idx, without_comma)):
        record["input_ids"] = token_ids
    return data

In [10]:
def padding(batch_data, tok2idx):
    """Collate ``(token_ids, label)`` pairs into right-padded LongTensors.

    Args:
        batch_data: sequence of ``(list_of_ids, int_label)`` pairs.
        tok2idx: vocabulary; ``tok2idx['<pad>']`` is the fill value.

    Returns:
        ``{"input_ids": (batch, max_len) LongTensor, "labels": (batch,) LongTensor}``.
    """
    id_lists, label_values = zip(*batch_data)
    sequences = [torch.tensor(ids, dtype=torch.long) for ids in id_lists]
    padded = pad_sequence(sequences, batch_first=True,
                          padding_value=tok2idx['<pad>'])
    targets = torch.tensor(label_values, dtype=torch.long)
    return {"input_ids": padded, "labels": targets}

In [11]:
def preprocess(data_name, without_comma=True, batch_size=32, test_size=0.2, random_state=42):
    """Read a JSON-lines FOL2NS file and build train/validation DataLoaders.

    Each non-blank line must be a JSON object with a ``'FOL2NS'`` text field and a
    ``'QD'`` value. Labels are ``int(QD) - 1``, so class index ``i`` is ``QD = i + 1``.
    Token ids come from the notebook-level ``tok2idx``.

    Args:
        data_name: path of the JSON-lines file.
        without_comma: drop ``","`` tokens (default True, as every call in this notebook uses).
        batch_size: DataLoader batch size for both splits.
        test_size: validation fraction passed to ``train_test_split``.
        random_state: split seed; a fixed value makes repeated calls return the same split.

    Returns:
        ``(train_loader, valid_loader)``. Neither loader shuffles.

    NOTE(review): both loaders use ``drop_last=True``, so up to ``batch_size - 1``
    validation samples are never evaluated. Downstream cells rely on full-size
    batches, so this is left unchanged.
    """
    with open(data_name, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a stray empty line) instead of failing in json.loads.
        data = [json.loads(line) for line in f if line.strip()]

    data_new = reconstruct_data(data, tok2idx, without_comma)
    inputs = [item["input_ids"] for item in data_new]
    labels = [int(item["QD"]) - 1 for item in data_new]
    pairs = list(zip(inputs, labels))
    train_pairs, valid_pairs = train_test_split(pairs, test_size=test_size, random_state=random_state)

    def collate(batch):
        return padding(batch, tok2idx)

    train_loader = DataLoader(train_pairs, batch_size=batch_size, drop_last=True, collate_fn=collate)
    valid_loader = DataLoader(valid_pairs, batch_size=batch_size, drop_last=True, collate_fn=collate)
    return train_loader, valid_loader

In [45]:
# 128 plural person nouns; lowercased and merged into `predicate` in the vocabulary cell.
person = [
    'Accountants', 'Actors', 'Actuaries', 'Adults', 'Advisors', 'Agents', 'Allergists', 'Analysts',
    'Anthropologists', 'Archaeologists', 'Artists', 'Astronomers', 'Athletes', 'Attackers', 'Audiologists', 'Auditors',
    'Babies', 'Bailiffs', 'Bakers', 'Ballerinas', 'Barbers', 'Bartenders', 'Bloggers', 'Boxers',
    'Breadwinners', 'Butchers', 'Butlers', 'Captains', 'Cartographers', 'Cashiers', 'Chiropractors', 'Cleaners',
    'Clerks', 'Conductors', 'Cooks', 'Cricketers', 'Crooks', 'Cyclists', 'Cynics', 'Dancers',
    'Defenders', 'Dentists', 'Directors', 'Drillers', 'Drivers', 'Economists', 'Electricians', 'Engineers',
    'Epidemiologists', 'Experts', 'Farmers', 'Fighters', 'Firemen', 'Fishermen', 'Footballers', 'Foresters',
    'Ghosts', 'Grandmasters', 'Guests', 'Gymnasts', 'Hairdressers', 'Helpers', 'Historians', 'Hosts',
    'Jewelers', 'Judges', 'Jurors', 'Kings', 'Knights', 'Lawyers', 'Lecturers', 'Librarians',
    'Machinists', 'Masters', 'Mathematicians', 'Mechanics', 'Monologists', 'Musicians', 'Opticians', 'Painters',
    'Parents', 'Patients', 'Pavers', 'Philosophers', 'Photographers', 'Physicians', 'Physicists', 'Pilots',
    'Players', 'Playmakers', 'Plumbers', 'Poets', 'Policemen', 'Princes', 'Princesses', 'Principals',
    'Prisoners', 'Professors', 'Psychologists', 'Publishers', 'Quants', 'Queens', 'Researchers', 'Roofers',
    'Sailors', 'Scholars', 'Scientists', 'Scorers', 'Scribes', 'Secretaries', 'Settlers', 'Sheriffs',
    'Soldiers', 'Strategists', 'Students', 'Surgeons', 'Surveyors', 'Teachers', 'Technicians', 'Therapists',
    'Tourists', 'Traders', 'Veterinarians', 'Violinists', 'Visitors', 'Waiters', 'Warlords', 'Witches']
    

In [46]:
# 88 plural object nouns; lowercased and merged into `predicate` in the vocabulary cell.
thing = [
    "Apples", "Amulets", "Arrows", "Axes", "Backpacks", "Batteries", "Belts", "Bells", "Boots",
    "Bolts", "Books", "Boxes", "Bowls", "Bows", "Bracelets", "Bracers", "Brooches", "Buckets",
    "Candles", "Chalices", "Chests", "Cogs", "Coins", "Compasses", "Crates", "Crossbows",
    "Crowbars", "Cups", "Daggers", "Drums", "FishingRods", "Flasks", "Flutes", "Gauntlets",
    "Gems", "Glasses", "Gloves", "Greaves", "Hammers", "Hats", "Helmets", "Horns",
    "Jars", "Keys", "Lanterns", "Lockets", "Maps", "Masks", "Mirrors", "Nets",
    "Necklaces", "Notebooks", "OilFlasks", "Orbs", "Paintings", "Pauldrons", "Plates", "Pliers",
    "Pears", "Pipes", "Pouches", "Potions", "Quills", "Ropes", "Runes", "Sashes", "Satchels",
    "ScrollCases", "Scrolls", "Saws", "Screwdrivers", "Shields", "Shovels", "Spears", "Staffs",
    "Statues", "SwordSheaths", "Swords", "Tablets", "Talismans", "Tongs", "Torches", "Trinkets",
    "Trunks", "Vases", "Vials", "Wands", "Wrenches"
]

In [47]:
# 98 adjectives (the name `un_pre` presumably means unary predicates — assumption);
# lowercased and merged into `predicate` in the vocabulary cell.
un_pre = [
    'Active', 'Alert', 'Ambitious', 'Artistic', 'Bored', 'Brave', 'Busy', 'Calm',
    'Careless', 'Cautious', 'Charming', 'Cheerful', 'Clever', 'Clumsy', 'Cold',
    'Confident', 'Creative', 'Critical', 'Curious', 'Demanding', 'Determined',
    'Diligent', 'Disorganized', 'Distracted', 'Efficient', 'Elegant', 'Energetic',
    'Experienced', 'Fair', 'Fearless', 'Focused', 'Friendly', 'Funny', 'Generous',
    'Graceful', 'Hardworking', 'Helpful', 'Honest', 'Humble', 'Idealistic',
    'Impatient', 'Junior', 'Kind', 'Late', 'Lazy', 'Loyal', 'Messy', 'Modest',
    'Motivated', 'Naive', 'Nervous', 'Neutral', 'New', 'Old', 'Open', 'Organized',
    'Passionate', 'Patient', 'Picky', 'Polite', 'Pragmatic', 'Proud', 'Punctual',
    'Quiet', 'Realistic', 'Rebellious', 'Relaxed', 'Reliable', 'Reserved', 'Rude',
    'Selfish', 'Senior', 'Serious', 'Short', 'Shy', 'Silent', 'Skilled', 'Slow',
    'Smart', 'Social', 'Strict', 'Strong', 'Stubborn', 'Stylish', 'Talented',
    'Talkative', 'Tall', 'Thoughtful', 'Tired', 'Unfair', 'Unreliable',
    'Unsocial', 'Visionary', 'Warm', 'Weak', 'Wise', 'Witty', 'Young'
]

In [48]:
# 88 verbs (the name `bin_pre` presumably means binary predicates — assumption);
# lowercased and merged into `predicate` in the vocabulary cell.
# NOTE: 'Pay' also appears in `tern_pre`, so it occurs twice in the raw token list.
bin_pre = [
    'Accompany', 'Accuse', 'Admire', 'Advise', 'Align', 'Approach', 'Argue', 'Assist',
    'Betray', 'Blame', 'Brief', 'Challenge', 'Collaborate', 'Comment', 'Compare', 'Compete',
    'Compliment', 'Confront', 'Consult', 'Contact', 'Convince', 'Criticize', 'Deceive', 'Demand',
    'Discipline', 'Discuss', 'Dismiss', 'Doubt', 'Employ', 'Engage', 'Envy', 'Evaluate',
    'Fire', 'Follow', 'Fund', 'Greet', 'Guide', 'Hate', 'Help', 'Ignore',
    'Inform', 'Instruct', 'Insult', 'Interrupt', 'Invite', 'Involve', 'Judge', 'Know',
    'Lecture', 'Like', 'Listen', 'Manage', 'Mentor', 'Monitor', 'Motivate', 'Negotiate',
    'Notify', 'Observe', 'Oppose', 'Pay', 'Persuade', 'Praise', 'Prefer', 'Protect',
    'Provoke', 'Punish', 'Question', 'Refer', 'Reject', 'Remind', 'Replace', 'Report',
    'Request', 'Respect', 'Reward', 'Schedule', 'Scold', 'Shadow', 'Sponsor', 'Supervise',
    'Support', 'Teach', 'Train', 'Trust', 'Undermine', 'Love', 'Value', 'Warn'
]

In [49]:
# 35 verbs (the name `tern_pre` presumably means ternary predicates — assumption);
# lowercased and merged into `predicate` in the vocabulary cell.
# NOTE: 'Pay' also appears in `bin_pre`, so it occurs twice in the raw token list.
tern_pre = [
    'Allocate', 'Assign', 'Award', 'Bring', 'Contribute',
    'Convey', 'Consign', 'Delegate', 'Deliver', 'Dispatch', 'Distribute', 'Donate',
    'Entrust', 'Explain', 'Forward', 'Furnish', 'Give', 'Grant', 'Hand', 'Introduce',
    'Lend', 'Loan', 'Offer', 'Pass', 'Pay', 'Post', 'Present', 'Provide', 'Recommend',
    'Sell', 'Send', 'Share', 'Show', 'Supply','Transfer'
]

In [50]:
# Special symbols come first, so '<pad>' gets index 0: train/validate build their
# padding masks against id 0 by default.
symbol = ['<pad>','<bos>',',','.']
quanti = ['some', 'all', 'Some', 'All']
LO = ['and', 'or', 'implies']
others = ['the', 'a', 'to', 'by', 'are', 'of', 'in', 'which', 'that', 'It', 'it', 'is', 'who', 'day', 'careful', 'inspection', 
          'with', 'without', 'will', 'after', 'After', 'end', 'occasionally', 'effectively', 'regularly', 'timely', 'case', 
         'planning', 'manner', 'great', 'care', 'exception']
predicate = person + thing + un_pre + bin_pre + tern_pre
predicate = [i.lower() for i in predicate]
# 'pay' is listed in both bin_pre and tern_pre. Keep only the first occurrence of every
# token so each id maps to exactly one token and no embedding row goes unused.
# NOTE: this shifts the ids after the removed duplicate, so checkpoints saved with the
# old (duplicated) vocabulary are not loadable into a model built from this one.
toks = list(dict.fromkeys(symbol + quanti + predicate + LO + others))
enc_voc_size = len(toks)
tok2idx = tok_to_idx(toks)
idx2tok = idx_to_tok(toks)
src_pad_idx = tok2idx['<pad>']
assert src_pad_idx == 0, "train/validate default to pad id 0"
assert len(tok2idx) == len(idx2tok) == enc_voc_size, "vocabulary must be one-to-one"

## Model

In [51]:
class Embeddings(nn.Module):
    """Scaled token embedding plus a learned absolute position table, then dropout.

    Args:
        vocb_size: number of token ids.
        max_len: longest supported sequence (rows of the position table).
        d_model: embedding width; token vectors are multiplied by ``sqrt(d_model)``.
        device: kept for backward compatibility. The position slice is moved to
            the device of the token embeddings, so the module works wherever it is placed.
        dropout: dropout probability applied to the sum.
    """

    def __init__(self, vocb_size, max_len, d_model, device, dropout):
        super().__init__()
        # Creation order (embedding, then position table) is kept so seeded
        # initialisation produces the same weights as before.
        self.we = nn.Embedding(vocb_size, d_model)
        self.pe = nn.Parameter(torch.randn(1, max_len, d_model))
        self.device = device
        self.dropout = nn.Dropout(p=dropout)
        self.scale = math.sqrt(d_model)

    def forward(self, x):
        """Embed a ``(batch, seq_len)`` id tensor into ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the position table length ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional table size max_len={max_len}"
            )
        tok = self.we(x) * self.scale
        # (1, seq_len, d_model) broadcasts over the batch dimension.
        pos = self.pe[:, :seq_len, :].to(tok.device)
        return self.dropout(tok + pos)

In [52]:
class Classifier(nn.Module):
    """Transformer-encoder classifier that predicts from position 0 of the sequence.

    Position 0 is the ``"<bos>"`` token that ``tokenization`` prepends to every input.

    NOTE: the embedding is built with the notebook-level global ``device``. When
    ``num_layers`` is omitted, the encoder depth comes from the global ``n_layers``
    (the original behaviour).
    """

    def __init__(self, vocb_size, max_len, d_model, n_heads, n_classes, dropout,
                 num_layers=None, dim_feedforward=2048, squash_logits=True):
        """
        Args:
            vocb_size: vocabulary size of the token embedding.
            max_len: longest sequence the learned position table supports.
            d_model: model width.
            n_heads: attention heads per encoder layer.
            n_classes: number of output classes.
            dropout: dropout applied after the embedding sum. The encoder layers
                always use dropout 0.
            num_layers: encoder depth; ``None`` falls back to the global ``n_layers``.
            dim_feedforward: hidden size of each layer's feed-forward sub-block.
            squash_logits: if True (default), pass the class scores through ``tanh``.
        """
        super().__init__()
        if num_layers is None:
            num_layers = n_layers
        self.squash_logits = squash_logits
        self.emb = Embeddings(vocb_size, max_len, d_model, device, dropout)
        self.encoder = TransformerEncoderWithHooks(d_model=d_model,
                                                   nhead=n_heads,
                                                   num_layers=num_layers,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=0)
        self.classifierhead = nn.Linear(d_model, n_classes)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Return class scores of shape ``(batch, n_classes)`` for id tensor ``src``."""
        hidden = self.encoder(self.emb(src), src_mask=src_mask,
                              src_key_padding_mask=src_key_padding_mask)
        scores = self.classifierhead(hidden[:, 0, :])
        if self.squash_logits:
            # NOTE(review): tanh bounds every score to [-1, 1] before CrossEntropyLoss and
            # softmax. With 4 classes the largest reachable probability is
            # e / (e + 3/e) ~= 0.71. Construct with squash_logits=False for unbounded logits.
            return torch.tanh(scores)
        return scores

In [53]:
def save_attention_weights_hook_generator(layer_idx, store=None):
    """Build a forward hook that records one attention module's weights.

    The returned hook uses the ``register_forward_hook`` signature. When the module
    is in eval mode, it stores ``output[1]`` (the attention weights), detached and
    moved to CPU, under ``layer_idx``. In training mode it does nothing.

    Args:
        layer_idx: key under which this layer's weights are stored.
        store: mapping to write into. ``None`` (default) means the notebook-level
            ``attention_weights_store``, looked up each time the hook fires.

    Returns:
        The hook function. It raises ``ValueError`` if the module output has no
        weights, e.g. when attention was called with ``need_weights=False``.
    """

    def hook(module, input, output):
        if module.training:
            return None
        if not isinstance(output, tuple) or len(output) < 2 or output[1] is None:
            raise ValueError(f'Could not find attention weights in the output of the self-attention module {layer_idx}!')
        target = attention_weights_store if store is None else store
        target[layer_idx] = output[1].detach().cpu()
    return hook

In [54]:
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """``nn.TransformerEncoderLayer`` whose self-attention always computes per-head weights.

    ``_sa_block`` still returns only the attended values, as the parent's does. The
    weights appear in the output tuple of ``self.self_attn``, where a forward hook
    registered on that module can capture them.
    """

    def _sa_block(self, x, attn_mask, key_padding_mask, **kwargs):
        # Per-head weights (average_attn_weights=False) are requested so hooks
        # receive one attention map per head rather than the head average.
        attention_options = dict(
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=False,
            **kwargs,
        )
        attended = self.self_attn(x, x, x, **attention_options)[0]
        return self.dropout1(attended)

In [55]:
class TransformerEncoderWithHooks(nn.Module):
    """Stack of ``MyTransformerEncoderLayer`` blocks followed by a final LayerNorm.

    Each layer's ``self_attn`` gets a forward hook from
    ``save_attention_weights_hook_generator(layer_index)``. In eval mode that hook
    writes the layer's attention weights into the notebook-level
    ``attention_weights_store``. ``forward`` clears the store first, so after a call
    it only holds weights from that call.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers

        self.layers = nn.ModuleList()
        for layer_index in range(num_layers):
            encoder_layer = MyTransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,  # tensors are (batch, seq, d_model)
            )
            encoder_layer.self_attn.register_forward_hook(
                save_attention_weights_hook_generator(layer_index)
            )
            self.layers.append(encoder_layer)

        # Normalisation applied once after the whole stack.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        attention_weights_store.clear()  # forget weights captured by the previous call
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(
                hidden,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
            )
        return self.norm(hidden)

## Training and validation loops

In [56]:
def train(train_data, model, optimizer, loss_fn, device, pad_idx=0):
    """Run one optimisation pass over ``train_data`` and return the mean batch loss.

    Args:
        train_data: iterable of batches with ``"input_ids"`` and ``"labels"`` tensors
            (as built by ``padding``).
        model: classifier called as ``model(src=..., src_mask=None, src_key_padding_mask=...)``.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: loss applied to the model outputs and gold labels.
        device: device every batch is moved to.
        pad_idx: token id treated as padding in the key-padding mask (default 0,
            the id of '<pad>' in this notebook's vocabulary).

    Returns:
        Mean of the per-batch losses as a Python float (NaN if there were no batches).
    """
    model.train()
    batch_losses = []
    for items in tqdm(train_data, total=len(train_data), desc='Train', leave=False):
        optimizer.zero_grad()
        inputs = items["input_ids"].to(device)
        gold = items["labels"].to(device)
        # True at padding positions so attention ignores them.
        mask = inputs == pad_idx
        logits = model(src=inputs, src_mask=None, src_key_padding_mask=mask)
        loss = loss_fn(logits, gold)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    # Collect losses in a list rather than a pre-sized torch.empty buffer, so no
    # uninitialised slot can leak into the mean if len() and the iteration count differ.
    return torch.tensor(batch_losses).mean().item()

In [57]:
def validate(valid_data, model, loss_fn, device, topk, num_classes=4, pad_idx=0):
    """Evaluate ``model`` on ``valid_data`` without gradient tracking.

    Args:
        valid_data: iterable of batches with ``"input_ids"`` and ``"labels"`` tensors.
        model: classifier called as ``model(src=..., src_mask=None, src_key_padding_mask=...)``.
        loss_fn: loss applied to the model outputs and gold labels.
        device: device every batch is moved to.
        topk: number of highest-probability classes recorded per sample.
        num_classes: number of classes for the F1 metrics (default 4).
        pad_idx: token id treated as padding in the key-padding mask.

    Returns:
        ``(mean batch loss, accuracy over all samples, macro-F1 over all samples,
        per-class F1 tensor on CPU, list of per-batch prediction tensors,
        list of per-batch (top-k probabilities, top-k class indices) tuples)``.
    """
    model.eval()
    # The F1 metrics accumulate over the whole split and are computed once at the
    # end. Averaging per-batch F1 scores does not give the dataset-level F1.
    macro_f1_metric = MulticlassF1Score(num_classes=num_classes, average='macro').to(device)
    per_class_f1_metric = MulticlassF1Score(num_classes=num_classes, average=None).to(device)

    all_preds, topk_probs = [], []
    batch_losses = []
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for items in tqdm(valid_data, total=len(valid_data), desc='Eval', leave=False):
            inputs = items["input_ids"].to(device)
            gold = items["labels"].to(device)
            mask = inputs == pad_idx
            logits = model(src=inputs, src_mask=None, src_key_padding_mask=mask)
            preds = torch.argmax(logits, dim=-1)
            probs = F.softmax(logits, dim=-1)
            topk_val, topk_idx = probs.topk(topk, dim=-1)

            all_preds.append(preds)
            topk_probs.append((topk_val, topk_idx))

            macro_f1_metric.update(preds, gold)
            per_class_f1_metric.update(preds, gold)
            batch_losses.append(loss_fn(logits, gold).item())
            n_correct += (preds == gold).sum().item()
            n_seen += gold.size(0)

    valid_loss = torch.tensor(batch_losses).mean().item()
    valid_accuracy = n_correct / n_seen if n_seen else float("nan")
    valid_macroF1 = macro_f1_metric.compute().item()
    valid_perF1 = per_class_f1_metric.compute().cpu()
    return valid_loss, valid_accuracy, valid_macroF1, valid_perF1, all_preds, topk_probs

In [59]:
def _plot_training_history(train_losses, valid_losses, valid_accuracies,
                           valid_macroF1s, valid_perF1s, n_plot_classes=3):
    """Plot loss, accuracy/macro-F1 and per-class F1 curves against 1-based epochs."""
    epochs = range(1, len(train_losses) + 1)

    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, valid_losses, label='Valid Loss')
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.show()

    plt.plot(epochs, valid_accuracies, label='Valid Accuracy')
    plt.plot(epochs, valid_macroF1s, label='Valid macro_F1score')
    plt.xlabel("Epochs")
    plt.ylabel("Metric")
    plt.legend()
    plt.grid(True)
    plt.show()

    # Labels are QD - 1, so class index i is QD = i + 1.
    for i in range(n_plot_classes):
        plt.plot(epochs, [per[i] for per in valid_perF1s], label=f'QD={i+1}')
    plt.xlabel('Epochs')
    plt.ylabel('Per-class F1 Score')
    plt.legend()
    plt.grid(True)
    plt.show()


def main(data_name, epoch, best_acc, early_stop, topk, without_comma=True):
    """Train the notebook-level model, validate each epoch, checkpoint the best, and plot curves.

    Reads the notebook-level globals ``model``, ``optimizer``, ``loss_fn``, ``device`` and ``lr``.

    Args:
        data_name: JSON-lines data file passed to ``preprocess``; also the checkpoint-name prefix.
        epoch: maximum number of epochs.
        best_acc: validation accuracy a checkpoint must beat (``float("-inf")`` saves the first epoch).
        early_stop: stop after this many consecutive epochs without an accuracy improvement.
        topk: number of top classes ``validate`` records per sample.
        without_comma: forwarded to ``preprocess``; drops ``","`` tokens when True.

    Returns:
        ``(all_preds, topk_probs, valid_accuracies, valid_macroF1s, valid_perF1s)``.
        ``all_preds``/``topk_probs`` come from the last epoch that ran (empty if none ran).
    """
    train_data, valid_data = preprocess(data_name, without_comma=without_comma)
    train_losses, valid_losses, valid_accuracies = [], [], []
    valid_perF1s, valid_macroF1s = [], []
    all_preds, topk_probs = [], []
    lr_string = abs(int(math.log10(lr)))
    last_improve = 0
    best_path = None  # checkpoint saved by THIS run; the only file this function deletes
    print(data_name)
    for step in tqdm(range(epoch)):
        train_loss = train(train_data=train_data, model=model, optimizer=optimizer, loss_fn=loss_fn, device=device)
        valid_loss, valid_acc, valid_macroF1, valid_perF1, all_preds, topk_probs = validate(
            valid_data=valid_data, model=model, loss_fn=loss_fn, device=device, topk=topk)
        print(f'Epoch: {step + 1}, Train Loss: {train_loss:.2f}, Val Loss: {valid_loss:.2f}')
        print(f'Val Accuracy: {valid_acc:.2f}, Val macroF1: {valid_macroF1:.2f}')
        if isinstance(valid_perF1, torch.Tensor):
            valid_perF1_ls = valid_perF1.cpu().tolist()
        else:
            valid_perF1_ls = list(valid_perF1)

        # Class index i is QD = i + 1 (labels are QD - 1), matching the per-class plot.
        print(f'Val_F1 QD=1: {valid_perF1_ls[0]:.2f}, ',
              f'QD=2: {valid_perF1_ls[1]:.2f}, ',
              f'QD=3: {valid_perF1_ls[2]:.2f}')
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_acc)
        valid_perF1s.append(valid_perF1_ls)
        valid_macroF1s.append(valid_macroF1)

        if valid_acc > best_acc:
            best_acc = valid_acc
            last_improve = step
            # Replace only this run's previous best; other *.pt files in the working
            # directory are left alone.
            if best_path is not None and os.path.exists(best_path):
                os.remove(best_path)
            best_path = f'{data_name}_epoch{step+1}_lr{lr_string}_{valid_acc:.2f}.pt'
            torch.save({'best_model': model.state_dict()}, best_path)
        elif step - last_improve >= early_stop:
            print(f'Early stopping: no improvement for {early_stop} epochs.')
            break

    _plot_training_history(train_losses, valid_losses, valid_accuracies,
                           valid_macroF1s, valid_perF1s)
    return all_preds, topk_probs, valid_accuracies, valid_macroF1s, valid_perF1s

In [28]:
def load_checkpoint(path, device,model):
    """Load the ``'best_model'`` state dict saved by ``main`` into ``model``.

    The model is switched to eval mode and returned (the same object).
    NOTE: ``torch.load`` unpickles the file; only load checkpoints you trust.
    """
    checkpoint = torch.load(path, map_location=device)
    weights = checkpoint['best_model']
    model.load_state_dict(weights)
    model.eval()
    return model

## Analysis helpers: attention heatmaps and error inspection

In [29]:
def check_attention_weights():
    """Print the shape and first-sample row sums of every captured attention tensor.

    Reads the notebook-level ``attention_weights_store``. Each row of a softmax
    attention map should sum to 1, which makes this a quick sanity check.
    """
    print("\n--- Stored Attention Weights ---")
    if not attention_weights_store:
        print("No attention weights were captured.")
        return
    for layer_idx, weights in attention_weights_store.items():
        # Expected shape: (batch_size, num_heads, seq_len, seq_len)
        print(f"Layer {layer_idx}: Weights = {weights.shape}")
        first_sample_row_sums = weights[0].sum(dim=-1).flatten()
        print(f"Weight row sums = {first_sample_row_sums}")

In [30]:
def attetion_info(sample_idx, all_preds, with_pad=False, file_prefix="type1", dpi=800):
    """Print one validation sample and plot/save its attention map for every layer and head.

    Reads notebook-level globals: ``last_inputs`` / ``last_labels`` (the last validation
    batch), ``idx2tok``, ``attention_weights_store`` (filled by the attention hooks), and
    ``n_layers`` / ``n_heads`` (used only in the file names).

    Args:
        sample_idx: row of the last validation batch to inspect.
        all_preds: per-batch predictions from ``validate``; the last entry is used.
        with_pad: if True keep padding positions (drawn as "-"); otherwise drop them.
        file_prefix: prefix of the saved PNG file names.
        dpi: resolution of the saved PNGs.
    """
    pad_tok = "<pad>"
    tokens = [idx2tok[i] for i in last_inputs[sample_idx].tolist()]
    if with_pad:
        tok = ["-" if t == pad_tok else t for t in tokens]
    else:
        # pad_sequence pads at the end, so the first len(tok) positions stay aligned.
        tok = [t for t in tokens if t != pad_tok]
    label = int(last_labels[sample_idx])
    pred = int(all_preds[-1][sample_idx])
    print("Input token:", " ".join(tok))
    print("Label:", label + 1)
    print("Prediction:", pred + 1)

    n_tok = len(tok)
    positions = np.arange(n_tok)
    for layer_idx, weights in attention_weights_store.items():
        n_head = weights.shape[1]
        for h in range(n_head):
            attn = weights[sample_idx, h, :n_tok, :n_tok].numpy()
            fig, ax = plt.subplots(figsize=(8, 8))
            image = ax.imshow(attn, aspect='auto')
            ax.set_xticks(positions)
            ax.set_xticklabels(tok, fontsize=10, rotation=45, ha='right')
            ax.set_yticks(positions)
            ax.set_yticklabels(tok, fontsize=10)
            ax.set_title(f"Sample:{sample_idx}, Layer:{layer_idx}, Head:{h}")
            ax.set_xlabel("Key Position")
            ax.set_ylabel("Query Position")
            fig.colorbar(image, ax=ax)

            filename = f"{file_prefix}_{n_layers}L{n_heads}H_Sample{sample_idx}_L{layer_idx}H{h}.png"
            fig.savefig(filename, dpi=dpi, bbox_inches='tight')
            # Display, then release: otherwise one figure per layer/head stays open.
            plt.show()
            plt.close(fig)

In [31]:
def error_case(all_preds, topk_probs, labels=None, tokens=None):
    """Print every misclassified sample of the last validation batch.

    Class indices are printed 1-based (QD numbering, since labels are ``QD - 1``).
    The ``logits`` field holds the top-k softmax probabilities recorded by ``validate``.

    Args:
        all_preds: per-batch prediction tensors; the last entry is inspected.
        topk_probs: per-batch ``(top-k probabilities, top-k class indices)``; the last is used.
        labels: gold labels for that batch; defaults to the global ``last_labels``.
        tokens: decoded sentences for that batch; defaults to the global ``all_token``.
    """
    if labels is None:
        labels = last_labels
    if tokens is None:
        tokens = all_token
    top, logit_idx = topk_probs[-1]
    qd_idx = logit_idx + 1  # class index i is QD = i + 1
    for idx, item in enumerate(all_preds[-1]):
        pred = int(item)
        gold = int(labels[idx])
        if pred != gold:
            print(f"-----Error{idx}:{tokens[idx]}-----")
            print(f"label:{gold + 1}, prediction:{pred + 1}.")
            print(f"logits:{top[idx].tolist()},logit_idx:{qd_idx[idx].tolist()}")

In [32]:
def correct_case(all_preds, topk_probs, labels=None, tokens=None):
    """Print every correctly classified sample of the last validation batch.

    Class indices, including the top-k indices, are printed 1-based (QD numbering,
    since labels are ``QD - 1``), the same way ``error_case`` prints them. Previously
    the top-k indices here were left 0-based.

    Args:
        all_preds: per-batch prediction tensors; the last entry is inspected.
        topk_probs: per-batch ``(top-k probabilities, top-k class indices)``; the last is used.
        labels: gold labels for that batch; defaults to the global ``last_labels``.
        tokens: decoded sentences for that batch; defaults to the global ``all_token``.
    """
    if labels is None:
        labels = last_labels
    if tokens is None:
        tokens = all_token
    top, logit_idx = topk_probs[-1]
    qd_idx = logit_idx + 1  # class index i is QD = i + 1
    for idx, item in enumerate(all_preds[-1]):
        pred = int(item)
        gold = int(labels[idx])
        if pred == gold:
            print(f"-----Correct{idx}:{tokens[idx]}-----")
            print(f"label:{gold + 1}, prediction:{pred + 1}.")
            print(f"logits:{top[idx].tolist()},logit_idx:{qd_idx[idx].tolist()}")

# Run Main

## Parameter setting

In [61]:
# Model hyper-parameters.
# NOTE(review): `batch_size` is not passed to `preprocess` anywhere in this notebook;
# its DataLoaders use 32 independently, so keep the two in sync by hand.
batch_size: int = 32
max_len: int = 256    # rows of the learned positional table in Embeddings
d_model: int = 256
n_classes: int = 4
dropout: float = 0
topk: int = 4         # classes recorded per sample by validate

In [62]:
# Training-loop setting (not an optimizer option): `main` stops once this many
# epochs have passed since the last validation-accuracy improvement.
early_stop: int = 5

In [63]:
device = "cuda" # if torch.cuda.is_available() else "cpu"
loss_fn = nn.CrossEntropyLoss()

## Train the model

In [64]:
data_name = "FOL2NS.json"

In [65]:
# Per-run settings; n_layers / n_heads also appear in the attention-heatmap file names.
lr: float = 1e-5       # AdamW learning rate (also encoded in checkpoint names)
epoch: int = 15        # maximum number of training epochs
n_layers: int = 2      # encoder depth
n_heads: int = 2       # attention heads per layer

In [None]:
# Build the classifier, move it to `device`, and optimise all of its parameters with AdamW.
model = Classifier(
    vocb_size=enc_voc_size,
    max_len=max_len,
    d_model=d_model,
    n_heads=n_heads,
    n_classes=n_classes,
    dropout=dropout,
)
model = model.to(device)

optimizer = AdamW(model.parameters(), lr=lr)

In [67]:
# Layer index -> attention weights, written by the attention hooks during eval-mode
# forward passes (entries are assigned directly; the list factory is never used).
attention_weights_store = defaultdict(list)
# Release memory cached by PyTorch's CUDA allocator before training starts.
torch.cuda.empty_cache()

In [None]:
# best_acc=-inf means any finite first-epoch accuracy is checkpointed.
run_outputs = main(data_name, epoch, best_acc=float("-inf"), early_stop=early_stop, topk=topk)
all_preds, topk_probs, valid_accuracies, valid_macroF1s, valid_perF1s = run_outputs

In [None]:
def draw_confusion_matrix(loader, all_preds, display_labels=(1, 2, 3), figsize=(6, 6)):
    """Plot the confusion matrix of ``all_preds`` against the gold labels in ``loader``.

    ``all_preds`` must be the per-batch predictions obtained by iterating the same
    loader in the same order (e.g. the non-shuffled validation loader), so that
    the two concatenations line up sample by sample.

    Args:
        loader: iterable of batches, either dicts with a ``"labels"`` tensor or
            ``(inputs, labels)`` pairs.
        all_preds: per-batch predictions (tensors or array-likes) in loader order.
        display_labels: tick labels for the classes present in the gold labels,
            in sorted class order.
        figsize: matplotlib figure size.

    Raises:
        AssertionError: if the numbers of predictions and gold labels differ.
    """
    import warnings

    golds = np.concatenate(
        [(batch['labels'] if isinstance(batch, dict) else batch[1]).detach().cpu().numpy()
         for batch in loader],
        axis=0,
    )
    preds = np.concatenate(
        [p.detach().cpu().numpy() if torch.is_tensor(p) else np.asarray(p) for p in all_preds],
        axis=0,
    )
    assert preds.shape[0] == golds.shape[0], (
        f"{preds.shape[0]} predictions vs {golds.shape[0]} gold labels: loader and all_preds are misaligned"
    )

    labels = np.unique(golds)
    # confusion_matrix drops predictions whose class is not in `labels`; make that visible.
    unseen = np.setdiff1d(np.unique(preds), labels)
    if unseen.size:
        warnings.warn(
            f"Predicted classes {unseen.tolist()} never occur in the gold labels; "
            f"those predictions are excluded from the confusion matrix."
        )
    cm = confusion_matrix(golds, preds, labels=labels)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=display_labels)

    fig, ax = plt.subplots(figsize=figsize)
    disp.plot(ax=ax, cmap='Blues', colorbar=False)
    ax.set_title("Confusion Matrix")
    plt.show()

In [None]:
# `main` does not return its loaders, so rebuild them. The split is the same because
# preprocess uses a fixed random_state and its loaders do not shuffle.
train_data, valid_data = preprocess(data_name=data_name, without_comma=True)

In [None]:
# Display names in sorted class-index order (class index i is QD = i + 1).
qd_display_labels = ["QD=1", "QD=2", "QD=3"]
draw_confusion_matrix(valid_data, all_preds, display_labels=qd_display_labels)

## Attention Heatmaps

In [None]:
# Rebuild the loaders (same split, no shuffling) and decode the last validation batch.
# It is the last batch `validate` ran, so its attention weights are the ones still
# held in attention_weights_store.
all_token = []
train_data, valid_data = preprocess(data_name, without_comma=True)
last_batch = list(valid_data)[-1]
last_inputs = last_batch["input_ids"]
last_labels = last_batch["labels"]
for i, eg in enumerate(last_inputs.tolist()):
    # Drop padding rather than blanking it, so sentences carry no trailing spaces.
    sentence = " ".join(t for t in detokenization(eg, idx2tok).split() if t != "<pad>")
    print(i, sentence)
    all_token.append(sentence)

In [None]:
error_case(all_preds=all_preds, topk_probs=topk_probs)  # misclassified samples of the last batch

In [None]:
attetion_info(28, all_preds)  # heatmaps for row 28 of the last validation batch

In [None]:
attetion_info(19, all_preds)  # heatmaps for row 19 of the last validation batch