In [1]:
# !pip install torch==1.13.1+cu116 torchaudio==0.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
# !pip install transformers==4.35.2
# !pip install scikit-learn

In [2]:
import pandas as pd
# Load the cleaned review dataset. The path is relative, so it resolves against the
# notebook's working directory.
df = pd.read_csv("../data/AERA02_AptitudeAssessment_Dataset_NLP_cleaned_vi.csv")

In [3]:
import re
import string
def process_text(text):
    """Normalise one raw title/review string before tokenisation.

    Steps, in order:
      1. drop numeric HTML entities such as ``&#39;``;
      2. drop HTML tags;
      3. drop URLs (``http://`` or ``https://`` up to the next whitespace);
      4. turn ``/`` and ``-`` into spaces so joined words get separated;
      5. delete the remaining ASCII punctuation and the ellipsis character;
      6. lower-case the result.

    URL removal has to happen before step 4: once the slashes have been
    replaced by spaces, ``https?://`` can no longer match anything. URLs are
    removed wherever they occur, not only at the very start of the string.
    Tags are stripped before URLs so that a URL glued to a tag
    (``http://x.com<br>``) does not swallow the tag's closing ``>``.

    Args:
        text: the string to clean.

    Returns:
        The cleaned, lower-cased string. Runs of spaces are not collapsed.
    """
    text = re.sub(r"&#\d+;", "", text)
    text = re.sub(r"<.*?>", "", text)
    text = re.sub(r"https?://\S+", "", text)
    text = re.sub(r"[/-]", " ", text)
    # One C-level pass instead of a Python-level character filter.
    text = text.translate(str.maketrans("", "", string.punctuation + "…"))
    return text.lower()

def process_corpus(corpus):
    """Split a raw corpus string into space-delimited fragments.

    Newlines are treated as spaces and the text is then split on single
    spaces, so consecutive spaces produce empty strings in the result.

    Args:
        corpus: the full corpus as one string.

    Returns:
        list[str]: the fragments, in order of appearance.
    """
    _WORD_SPLIT = re.compile("([.,!?\"/':;)(])")

    def basic_tokenizer(sentence):
        """Split a sentence on whitespace and punctuation; return lower-cased non-punctuation tokens."""
        words = []
        for space_separated_fragment in sentence.strip().split():
            words.extend(_WORD_SPLIT.split(space_separated_fragment))
        return [w.lower() for w in words if w != '' and w != ' ' and w not in string.punctuation]

    # NOTE(review): basic_tokenizer is defined here but is not applied to the
    # fragments. It is kept because it may have been meant for a later
    # step -- confirm, then either use it or delete it.
    # The split result used to be assigned to a local and dropped, so the
    # function always returned None; it is now returned to the caller.
    return corpus.replace("\n", " ").split(" ")

In [4]:
# Replace every missing value in every column with an empty string, so the text
# cleaning applied to the "review" and "title" columns always receives a str.
# NOTE(review): this also applies to "score"; a missing score would become "" and
# the later astype("int") would fail on it -- confirm the column has no gaps.
df.fillna("", inplace=True)

In [5]:
# Store the rating as an integer column.
df["score"] = df["score"].astype("int")

# Run the same cleaning routine over both free-text columns.
for text_column in ("review", "title"):
    df[text_column] = df[text_column].apply(process_text)

## BERT training

In [6]:
import numpy as np
import os
import random
from pathlib import Path
import json

import torch
from tqdm.notebook import tqdm

from transformers import AutoTokenizer, AutoModel
from torch.utils.data import TensorDataset

from transformers import BertForSequenceClassification



class Config:
    """Hyper-parameters and runtime settings read by the rest of the notebook."""

    # --- reproducibility ---
    seed_val = 17        # seed handed to set_random_seed()
    random_state = 42    # random_state for the train/test split

    # --- hardware ---
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # --- optimisation ---
    epochs = 8
    batch_size = 32
    lr = 2e-5
    eps = 1e-8

    # --- data ---
    # NOTE(review): the split cells in this notebook pass test_size=0.10 literally,
    # so this value is only recorded in `params` -- confirm which one is intended.
    test_size = 0.15

    # --- model / tokenizer ---
    pretrained_model = 'bert-base-uncased'
    seq_length = 512     # used as max_length when encoding the texts
    add_special_tokens = True
    return_attention_mask = True
    pad_to_max_length = True
    do_lower_case = True
    return_tensors = 'pt'

    # --- paths ---
    # NOTE(review): absolute, machine-specific cache location.
    cache_dir = "/space/hotel/phit/personal/aera02-aisia/cache"

config = Config()

# params will be saved after training
# Snapshot of the settings as plain values. The device is stored as its string
# form; cache_dir is not part of the snapshot.
params = {
    name: (str(getattr(config, name)) if name == "device" else getattr(config, name))
    for name in (
        "seed_val",
        "device",
        "epochs",
        "batch_size",
        "seq_length",
        "lr",
        "eps",
        "pretrained_model",
        "test_size",
        "random_state",
        "add_special_tokens",
        "return_attention_mask",
        "pad_to_max_length",
        "do_lower_case",
        "return_tensors",
    )
}

### Split train/val/test set

In [7]:
#split train test
from sklearn.model_selection import train_test_split

# Hold out 10% of the rows as the test set, stratified on the score column so the
# label distribution is preserved in both parts.
# NOTE(review): test_size is the literal 0.10 here, while config.test_size is 0.15
# and is not used -- confirm which value is intended.
train_df_, test_df = train_test_split(df, 
                                      test_size=0.10, 
                                      random_state=config.random_state, 
                                      stratify=df.score.values)

In [8]:
def set_random_seed(seed_val):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (``torch.manual_seed`` plus ``torch.cuda.manual_seed_all``).

    Args:
        seed_val: the integer seed. The argument itself is used; the function
            no longer reads ``config.seed_val`` behind the caller's back.
    """
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
    
set_random_seed(config.seed_val)

In [9]:
# Carve the validation set (10%, stratified on score) out of the remaining training
# rows. The random state comes from the config, like the train/test split, instead
# of a second hard-coded copy of the same number.
train_df, val_df = train_test_split(train_df_,
                                    test_size=0.10,
                                    random_state=config.random_state,
                                    stratify=train_df_.score.values)

In [10]:
train_df.shape, val_df.shape, test_df.shape

((38914, 5), (4324, 5), (4805, 5))

### Load tokenizer

In [11]:
# Create the tokenizer that belongs to the checkpoint named in the config.
# NOTE(review): the dataset file name and the sample rows suggest Vietnamese text,
# while config.pretrained_model is 'bert-base-uncased' -- confirm that this
# checkpoint's vocabulary suits the data.
# NOTE(review): unlike the model load, no cache_dir is passed here.
tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model, 
                                          do_lower_case=config.do_lower_case)

In [12]:
tokenizer.decode(tokenizer.encode("This is example of tokenizer."))

'[CLS] this is example of tokenizer. [SEP]'

In [13]:
tokenizer.sep_token

'[SEP]'

In [14]:
train_df

Unnamed: 0,score,title,review,title2review,language
14223,5,nghỉ dưởng,hòn tầm phong cảnh rất đẹp bãi biển cát trắng...,"Nghỉ dưởng. Hòn Tầm phong cảnh rất đẹp , bãi b...",vi
34363,5,chuyến đi tuyệt vời,mình và đồng nghiệp có dịp ghé đất quảng và cô...,Chuyến đi tuyệt vời . Mình và đồng nghiệp có d...,vi
40890,5,trải nghiệm cuối tuần,khách sạn sạch sẽ view đẹp nhân viên n...,"Trải nghiệm cuối tuần. - khách sạn sạch sẽ, -...",vi
17361,5,kì nghỉ trải nghiệm rất tuyệt vời,vợ chồng tôi ở khu nghỉ dưỡng dốc lếch này thấ...,Kì nghỉ trải nghiệm rất tuyệt vời. Vợ chồng tô...,vi
40440,5,trải nghiệm tuyệt vời,khách sạn quá đẹp quá rộng rãi và thoáng mát n...,Trải nghiệm tuyệt vời. Khách sạn quá đẹp quá r...,vi
...,...,...,...,...,...
4331,4,phòng rộng rãi,khách sạn 4 sao trang trí theo phong cách hoàn...,"Phòng rộng rãi. Khách sạn 4 sao, trang trí the...",vi
39902,5,banquet team,không gian sang trọng thiết kế độc đáo dịch vụ...,"Banquet team. Không gian sang trọng, thiết kế ...",vi
36763,5,đánh giá 5 sao,villa khá thoải mái rất gần trung tâm đầy đủ t...,"Đánh giá 5 sao. Villa khá thoải mái, rất gần t...",vi
43889,2,thất vọng,thất vọng về thái độ tiếp khách 3h sáng mình q...,"thất vọng. thất vọng về thái độ tiếp khách, 3h...",vi


## Create a CustomDataset class

In [15]:
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

class CustomDataset(Dataset):
    """Torch dataset that turns title/review rows into token-id sequences.

    Each row is encoded as ``"<title> <sep_token> <review>"`` (underscores are
    replaced by spaces first) with the settings taken from the module-level
    ``config``. If the frame has a ``score`` column it is exposed as the
    target; otherwise targets are ``None``.
    """

    def __init__(self, df, tokenizer):
        """Encode every row of ``df`` once, up front.

        Args:
            df: DataFrame with ``title`` and ``review`` columns and optionally
                a ``score`` column. It is not modified: the dataset works on a
                re-indexed copy.
            tokenizer: tokenizer exposing ``sep_token`` and ``batch_encode_plus``.
        """
        self.tokenizer = tokenizer
        # reset_index(drop=True) returns a new frame, so neither the re-indexing
        # nor the extra "encoded" column leaks back into the caller's DataFrame.
        self.df = df.reset_index(drop=True)
        self.df["encoded"] = (
            self.df.title.fillna("") + f" {tokenizer.sep_token} " + self.df.review.fillna("")
        )
        texts = [text.replace("_", " ") for text in self.df["encoded"]]
        # Only the input ids are kept; the attention mask is not stored here.
        # NOTE(review): with config.pad_to_max_length enabled the tokenizer is
        # asked to pad every sequence to config.seq_length -- confirm that this
        # is wanted.
        self.encoded = tokenizer.batch_encode_plus(
            texts,
            max_length=config.seq_length,
            add_special_tokens=config.add_special_tokens,
            return_attention_mask=config.return_attention_mask,
            pad_to_max_length=config.pad_to_max_length,
            truncation=True,
        )["input_ids"]
        # Targets exist only when the frame carries a "score" column.
        self.targets = self.df["score"] if "score" in self.df.columns else None

    def __len__(self):
        """Number of rows in the dataset."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``{'ids': LongTensor, 'target': 0-dim tensor or None}`` for one row."""
        return {
            'ids': torch.tensor(self.encoded[index]),
            # None when the dataset has no targets.
            'target': None if self.targets is None else torch.tensor(self.targets.iloc[index]),
        }

In [16]:
# Tokenise the training and validation frames once, up front.
# NOTE(review): no dataset is built from test_df in this notebook.
dataset_train = CustomDataset(train_df, tokenizer)
dataset_val = CustomDataset(val_df, tokenizer)



In [17]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
pad_token_id = tokenizer.pad_token_id

def collate_fn(batch):
    """Merge dataset items into one batch dictionary.

    Every id sequence is right-padded with the module-level ``pad_token_id``
    up to the longest sequence in the batch. The mask is True wherever the id
    differs from ``pad_token_id``, so pad tokens already present in a sequence
    are masked as well.

    Args:
        batch: list of ``{"ids": 1-D LongTensor, "target": tensor or None}``.

    Returns:
        dict with ``"ids"`` and ``"masks"`` (stacked row-wise) and, when the
        first item has a target, ``"target"`` flattened to one dimension.
    """
    sequences = [item["ids"] for item in batch]
    labels = [item["target"] for item in batch]
    longest = max(len(seq) for seq in sequences)

    padded_rows, mask_rows = [], []
    for seq in sequences:
        shortfall = longest - len(seq)
        if shortfall > 0:
            filler = torch.full((shortfall,), pad_token_id, dtype=torch.long)
            seq = torch.cat((seq, filler))
        padded_rows.append(seq)
        mask_rows.append(seq != pad_token_id)

    outputs = {
        "ids": torch.vstack(padded_rows),
        "masks": torch.vstack(mask_rows),
    }
    # Unlabelled data: the first item decides whether targets are returned.
    if labels[0] is not None:
        outputs["target"] = torch.vstack(labels).view(-1)
    return outputs

# Training batches are drawn in random order; collate_fn pads and builds the masks.
dataloader_train = DataLoader(dataset_train, 
                              sampler=RandomSampler(dataset_train), 
                              collate_fn=collate_fn,
                              batch_size=config.batch_size)

# Validation batches keep the dataset order.
dataloader_validation = DataLoader(dataset_val, 
                                   sampler=SequentialSampler(dataset_val), 
                                   collate_fn=collate_fn,
                                   batch_size=config.batch_size)

In [18]:
next(iter(dataloader_train))

{'ids': tensor([[  101,  1047,  3270,  ...,     0,     0,     0],
         [  101,  2008,  3854,  ...,     0,     0,     0],
         [  101, 22775,  5495,  ...,     0,     0,     0],
         ...,
         [  101,  1047,  3270,  ...,     0,     0,     0],
         [  101, 21722,   102,  ...,     0,     0,     0],
         [  101, 18712, 12835,  ...,     0,     0,     0]]),
 'masks': tensor([[ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False],
         ...,
         [ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False]]),
 'target': tensor([5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3, 5, 5, 5,
         4, 5, 5, 5, 5, 5, 5, 5])}

## Build the model

In [19]:
# Save and Load functions
def save_checkpoint(save_path, model, valid_loss):
    """Write the model weights and a validation loss to ``save_path``.

    Nothing is written when ``save_path`` is None.

    Args:
        save_path: destination file, or None to skip saving.
        model: module whose ``state_dict()`` is stored under 'model_state_dict'.
        valid_loss: value stored under 'valid_loss'.
    """
    if save_path is not None:
        payload = {'model_state_dict': model.state_dict(), 'valid_loss': valid_loss}
        torch.save(payload, save_path)
        print(f'Model saved to ==> {save_path}')

def load_checkpoint(load_path, model, device=None):
    """Restore model weights from a file written by ``save_checkpoint``.

    Args:
        load_path: checkpoint file, or None to do nothing.
        model: module that receives the stored 'model_state_dict'.
        device: where to map the stored tensors (anything ``torch.load``
            accepts as ``map_location``). Defaults to "cuda:0" when CUDA is
            available and "cpu" otherwise. The function used to read a global
            named ``device`` that is not defined at module level.

    Returns:
        The stored 'valid_loss', or None when ``load_path`` is None.
    """
    if load_path is None:
        return

    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')

    model.load_state_dict(state_dict['model_state_dict'])
    return state_dict['valid_loss']

def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Write the three training-history lists to ``save_path``.

    Nothing is written when ``save_path`` is None.

    Args:
        save_path: destination file, or None to skip saving.
        train_loss_list: stored under 'train_loss_list'.
        valid_loss_list: stored under 'valid_loss_list'.
        global_steps_list: stored under 'global_steps_list'.
    """
    if save_path is not None:
        history = {
            'train_loss_list': train_loss_list,
            'valid_loss_list': valid_loss_list,
            'global_steps_list': global_steps_list,
        }
        torch.save(history, save_path)
        print(f'Model saved to ==> {save_path}')
   
def load_metrics(load_path, device=None):
    """Read back the history lists written by ``save_metrics``.

    Args:
        load_path: metrics file, or None to do nothing.
        device: ``map_location`` for ``torch.load``. Defaults to "cuda:0" when
            CUDA is available and "cpu" otherwise. The function used to read a
            global named ``device`` that is not defined at module level.

    Returns:
        ``(train_loss_list, valid_loss_list, global_steps_list)``, or None
        when ``load_path`` is None.
    """
    if load_path is None:
        return

    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')
    return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']


In [20]:
# Load the pretrained encoder with a freshly initialised 6-way classification head.
# NOTE(review): the targets seen in the sample training batch are the raw scores
# 1..5, so num_labels=6 presumably leaves class index 0 unused -- confirm.
model = BertForSequenceClassification.from_pretrained(config.pretrained_model,
                                                      num_labels=6,
                                                      output_attentions=False,
                                                      output_hidden_states=False,
                                                      cache_dir=config.cache_dir)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import AdamW, get_linear_schedule_with_warmup
import numpy as np

def optimizer_scheduler(model, num_train_steps, lr=3e-5, weight_decay=0.001, warmup_ratio=0.05):
    """Build an AdamW optimizer and a linear warmup/decay schedule for ``model``.

    Parameters whose name contains "bias" or "LayerNorm.weight" are placed in a
    group without weight decay; every other parameter uses ``weight_decay``.

    Args:
        model: module whose ``named_parameters()`` are optimised.
        num_train_steps: total number of optimizer steps the schedule spans.
        lr: learning rate (default 3e-5, the previously hard-coded value).
        weight_decay: decay for the decayed group (default 0.001, as before).
        warmup_ratio: fraction of ``num_train_steps`` used for warmup
            (default 0.05, as before).

    Returns:
        ``(optimizer, scheduler)``.
    """
    no_decay = ("bias", "LayerNorm.weight")
    decayed, undecayed = [], []
    for name, parameter in model.named_parameters():
        if any(marker in name for marker in no_decay):
            undecayed.append(parameter)
        else:
            decayed.append(parameter)

    optimizer_parameters = [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]

    opt = AdamW(optimizer_parameters, lr=lr)
    sch = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=int(warmup_ratio * num_train_steps),
        num_training_steps=num_train_steps,
        last_epoch=-1,
    )
    return opt, sch

class BertTrainer:
    """Training and evaluation loop for a BERT-like sequence classifier.

    Batches are dictionaries with the keys "ids", "masks" and "target"; they
    are passed to the model as ``input_ids``, ``attention_mask`` and
    ``labels``. The loss and logits are read from the model output.
    """

    def __init__(
        self,
        model,
        tokenizer,
        train_dataloader,
        eval_dataloader=None,
        accumulation_steps=5,
        epochs=1,
        lr=5e-04,
        output_dir='./',
        output_filename='model_state_dict.pt',
        save=False,
        tabular=False,
    ):
        """
        Args:
            model: torch.nn.Module = a model with a BERT-like architecture that
                accepts ``input_ids``, ``attention_mask``, ``token_type_ids``
                and ``labels`` and returns an object with ``loss`` and ``logits``,
            tokenizer: = the tokenizer that belongs to the model (only stored),
            train_dataloader: torch.utils.data.DataLoader = yields training
                batches with "ids", "masks" and "target" keys,
            eval_dataloader: torch.utils.data.DataLoader = yields evaluation
                batches with the same keys,
            accumulation_steps: int = accepted for compatibility; currently unused,
            epochs: int = number of epochs to train,
            lr: float = learning rate for the optimizer,
            output_dir: str = directory the best checkpoint is written to,
            output_filename: str = checkpoint name prefix; the file written is
                "<output_filename>_epoch_<epoch>.pt",
            save: bool = whether to save the best model (by evaluation loss),
            tabular: bool = accepted for compatibility; currently unused,
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        # A plain AdamW over all parameters; no learning-rate schedule is attached.
        self.optimizer = AdamW(self.model.parameters(), lr=lr)

        # Kept as an attribute; the loss used for training is the one the model returns.
        self.loss_fn = nn.CrossEntropyLoss()
        self.output_dir = output_dir
        self.output_filename = output_filename
        self.save = save
        self.eval_loss = float('inf')  # lowest evaluation loss seen so far
        self.epochs = epochs
        self.epoch_best_model = 0  # epoch at which the lowest evaluation loss was seen

    def train(self, evaluate=False):
        """Train for ``self.epochs`` epochs, optionally evaluating after each one."""
        for epoch in range(self.epochs):
            self.iteration(epoch, self.train_dataloader)
            if evaluate and self.eval_dataloader is not None:
                self.iteration(epoch, self.eval_dataloader, train=False)

    def evaluate(self):
        """Run one evaluation pass over the evaluation dataloader (reported as epoch 0).

        Raises:
            ValueError: if the trainer was created without an evaluation dataloader.
        """
        if self.eval_dataloader is None:
            raise ValueError("evaluate() requires an eval_dataloader")
        epoch = 0
        self.iteration(epoch, self.eval_dataloader, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """Run one pass over ``data_loader``, print metrics, and maybe save the model.

        Args:
            epoch: epoch number, used for the progress bar and checkpoint name.
            data_loader: dataloader to iterate.
            train: True to update the weights, False to only evaluate.
        """
        # running totals for the epoch
        loss_accumulated = 0.
        correct_accumulated = 0
        samples_accumulated = 0
        preds_all = []
        labels_all = []

        if train:
            self.model.train()
        else:
            self.model.eval()

        # progress bar
        mode = "train" if train else "eval"
        batch_iter = tqdm(
            enumerate(data_loader),
            desc=f"EP ({mode}) {epoch}",
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )

        # Gradients are only tracked while training; evaluation passes do not
        # build an autograd graph.
        with torch.set_grad_enabled(train):
            for i, batch in batch_iter:

                batch = {key: value.to(self.device) for key, value in batch.items()}

                outputs = self.model(
                    input_ids=batch["ids"],
                    attention_mask=batch["masks"],
                    token_type_ids=None,
                    labels=batch["target"]
                )

                loss = outputs.loss
                logits = outputs.logits

                # compute gradients and update the weights
                if train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # number of correct predictions in this batch
                preds = logits.argmax(dim=-1)
                correct = preds.eq(batch["target"]).sum().item()

                # accumulate batch metrics and outputs
                loss_accumulated += loss.item()
                correct_accumulated += correct
                samples_accumulated += len(batch["target"])
                preds_all.append(preds.detach())
                labels_all.append(batch['target'].detach())

                batch_iter.set_postfix(loss=loss.item())

        # one tensor per epoch, moved to cpu for the sklearn metrics
        preds_all = torch.cat(preds_all, dim=0).cpu()
        labels_all = torch.cat(labels_all, dim=0).cpu()

        # metrics
        accuracy = accuracy_score(labels_all, preds_all)
        precision = precision_score(labels_all, preds_all, average='macro')
        recall = recall_score(labels_all, preds_all, average='macro')
        f1 = f1_score(labels_all, preds_all, average='macro')
        avg_loss_epoch = loss_accumulated / len(data_loader)

        # print metrics to console (fields separated by a comma and five spaces)
        print(",     ".join([
            f"samples={samples_accumulated}",
            f"correct={correct_accumulated}",
            f"acc={round(accuracy, 4)}",
            f"recall={round(recall, 4)}",
            f"prec={round(precision,4)}",
            f"f1={round(f1, 4)}",
            f"loss={round(avg_loss_epoch, 4)}",
        ]))

        # save the model if the evaluation loss is lower than the previous best
        if self.save and not train and avg_loss_epoch < self.eval_loss:

            # create directory and filepaths
            dir_path = Path(self.output_dir)
            dir_path.mkdir(parents=True, exist_ok=True)
            file_path = dir_path / f"{self.output_filename}_epoch_{epoch}.pt"

            # Delete the previous best checkpoint in pure Python rather than
            # through a shell command. It is skipped when it is the file about
            # to be written.
            file_path_best_model = dir_path / f"{self.output_filename}_epoch_{self.epoch_best_model}.pt"
            if file_path_best_model != file_path and file_path_best_model.exists():
                file_path_best_model.unlink()

            # save model
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()
            }, file_path)

            # update the new best loss and epoch
            self.eval_loss = avg_loss_epoch
            self.epoch_best_model = epoch

In [22]:
# Fine-tune with the settings from the config and keep the checkpoint with the
# lowest validation loss.
# NOTE(review): BertTrainer appends "_epoch_<n>.pt" to output_filename, so the saved
# file is named "model_state_dict.pt_epoch_<n>.pt" -- confirm that is intended.
trainer = BertTrainer(model, 
                      tokenizer, 
                      dataloader_train, 
                      dataloader_validation, 
                      epochs=config.epochs, 
                      lr=config.lr, 
                      output_dir="../model", 
                      output_filename="model_state_dict.pt", 
                      save=True)

# Evaluate on the validation loader after every training epoch.
trainer.train(evaluate=True)



EP (train) 0:   0%|| 0/1217 [00:00<?, ?it/s]

samples=38914,     correct=32964,     acc=0.8471,     recall=0.3152,     prec=0.3759,     f1=0.3167,     loss=0.4657


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


EP (eval) 0:   0%|| 0/136 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


samples=4324,     correct=3739,     acc=0.8647,     recall=0.4573,     prec=0.4838,     f1=0.456,     loss=0.3784


EP (train) 1:   0%|| 0/1217 [00:00<?, ?it/s]

samples=38914,     correct=33816,     acc=0.869,     recall=0.4966,     prec=0.5802,     f1=0.4982,     loss=0.3645


EP (eval) 1:   0%|| 0/136 [00:00<?, ?it/s]

samples=4324,     correct=3719,     acc=0.8601,     recall=0.5491,     prec=0.5614,     f1=0.5194,     loss=0.3877


EP (train) 2:   0%|| 0/1217 [00:00<?, ?it/s]

samples=38914,     correct=34226,     acc=0.8795,     recall=0.5394,     prec=0.6047,     f1=0.5491,     loss=0.3302


EP (eval) 2:   0%|| 0/136 [00:00<?, ?it/s]

samples=4324,     correct=3672,     acc=0.8492,     recall=0.5783,     prec=0.5437,     f1=0.5562,     loss=0.394


EP (train) 3:   0%|| 0/1217 [00:00<?, ?it/s]

samples=38914,     correct=34637,     acc=0.8901,     recall=0.6002,     prec=0.6653,     f1=0.6186,     loss=0.2962


EP (eval) 3:   0%|| 0/136 [00:00<?, ?it/s]

samples=4324,     correct=3756,     acc=0.8686,     recall=0.5362,     prec=0.5658,     f1=0.5316,     loss=0.3759


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


To make your changes take effect please reactivate your environment


EP (train) 4:   0%|| 0/1217 [00:00<?, ?it/s]

samples=38914,     correct=35106,     acc=0.9021,     recall=0.6436,     prec=0.702,     f1=0.6654,     loss=0.2623


EP (eval) 4:   0%|| 0/136 [00:00<?, ?it/s]

samples=4324,     correct=3688,     acc=0.8529,     recall=0.5411,     prec=0.5806,     f1=0.5562,     loss=0.4192


EP (train) 5:   0%|| 0/1217 [00:00<?, ?it/s]

samples=38914,     correct=35659,     acc=0.9164,     recall=0.6923,     prec=0.7408,     f1=0.712,     loss=0.2238


EP (eval) 5:   0%|| 0/136 [00:00<?, ?it/s]

samples=4324,     correct=3671,     acc=0.849,     recall=0.5224,     prec=0.5574,     f1=0.5344,     loss=0.4344


EP (train) 6:   0%|| 0/1217 [00:00<?, ?it/s]

samples=38914,     correct=36231,     acc=0.9311,     recall=0.7528,     prec=0.7897,     f1=0.7694,     loss=0.1871


EP (eval) 6:   0%|| 0/136 [00:00<?, ?it/s]

samples=4324,     correct=3689,     acc=0.8531,     recall=0.5489,     prec=0.5768,     f1=0.5596,     loss=0.4868


EP (train) 7:   0%|| 0/1217 [00:00<?, ?it/s]

samples=38914,     correct=36671,     acc=0.9424,     recall=0.7925,     prec=0.8227,     f1=0.8065,     loss=0.1555


EP (eval) 7:   0%|| 0/136 [00:00<?, ?it/s]

samples=4324,     correct=3703,     acc=0.8564,     recall=0.5442,     prec=0.566,     f1=0.5534,     loss=0.5017


In [23]:
trainer.evaluate()

EP (eval) 0:   0%|| 0/136 [00:00<?, ?it/s]

samples=4324,     correct=3703,     acc=0.8564,     recall=0.5442,     prec=0.566,     f1=0.5534,     loss=0.5017


In [None]:
from sklearn.metrics import f1_score

def f1_score_func(preds, labels):
    """Weighted F1 between the arg-max of ``preds`` (taken over axis 1) and ``labels``.

    Args:
        preds: array of per-class scores, one row per sample.
        labels: array of true class ids.

    Returns:
        The F1 score computed with ``average='weighted'``.
    """
    predicted_classes = np.argmax(preds, axis=1).flatten()
    true_classes = labels.flatten()
    return f1_score(true_classes, predicted_classes, average='weighted')

def accuracy_per_class(preds, labels, label_dict):
    """Print, for every class present in ``labels``, how many samples were predicted correctly.

    Args:
        preds: array of per-class scores, one row per sample.
        labels: array of true class ids.
        label_dict: mapping from class name to class id.

    Returns:
        None; the report is printed.
    """
    id_to_name = dict((class_id, name) for name, class_id in label_dict.items())

    predicted = np.argmax(preds, axis=1).flatten()
    actual = labels.flatten()

    for class_id in np.unique(actual):
        in_class = actual == class_id
        hits = int(np.count_nonzero(predicted[in_class] == class_id))
        total = int(np.count_nonzero(in_class))
        print(f'Class: {id_to_name[class_id]}')
        print(f'Accuracy: {hits}/{total}\n')

In [18]:
def evaluate(dataloader_val):
    """Run the module-level ``model`` over a dataloader without gradient tracking.

    Batches are expected in the format produced by this notebook's
    ``collate_fn``: a dict with "ids", "masks" and "target". The function used
    to read "input_ids" / "attention_mask" / "token_type_ids" / "label", which
    that collate function never produces.

    Args:
        dataloader_val: dataloader yielding labelled batches.

    Returns:
        ``(loss_val_avg, predictions, true_vals)``: the mean batch loss, the
        logits of all batches concatenated along axis 0, and the matching
        targets, the latter two as NumPy arrays.
    """
    model.eval()

    loss_val_total = 0
    predictions, true_vals = [], []

    for batch in dataloader_val:

        batch = {key: value.to(config.device) for key, value in batch.items()}

        with torch.no_grad():
            outputs = model(input_ids=batch["ids"],
                            attention_mask=batch["masks"],
                            labels=batch["target"])

        # With labels supplied, the first two outputs are the loss and the logits.
        loss = outputs[0]
        logits = outputs[1]
        loss_val_total += loss.item()

        predictions.append(logits.detach().cpu().numpy())
        true_vals.append(batch["target"].cpu().numpy())

    # calculate average val loss
    loss_val_avg = loss_val_total / len(dataloader_val)

    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)

    return loss_val_avg, predictions, true_vals


In [23]:
# Make sure the model sits on the configured device, then open an iterator over the
# validation batches so they can be pulled one at a time for manual inspection.
model.to(config.device)
data_iter = iter(dataloader_validation)

In [71]:
# Pull one validation batch and compare predictions with targets by hand.
# The batch keys are the ones produced by collate_fn ("ids", "masks", "target").
batch = next(data_iter)
batch = {key: value.to(config.device) for key, value in batch.items()}
with torch.no_grad():
    outputs = model(input_ids=batch["ids"],
                    attention_mask=batch["masks"],
                    labels=batch["target"])
    logits = outputs[1].detach().cpu().numpy()
targets = batch["target"].detach().cpu().numpy()

# Row 0: predicted class ids, row 1: true class ids.
print(np.array([logits.argmax(axis=1), targets]))

# accuracy_per_class prints its report and returns None, so it is called directly
# rather than wrapped in print().
# NOTE(review): these class names look like they belong to a different task; the
# targets in this notebook are review scores -- confirm the mapping.
accuracy_per_class(logits, targets, {"A1": 0, "A2": 1, "B1": 2, "B2": 3, "C1": 4, "C2": 5})

[[5 5 5 5 5 5 5 5 1 5 5 5 5 5 5 5 5 5 5 5 5 5 5 1 5 5 5 5 1 5 1 5]
 [5 5 4 3 5 5 4 5 2 5 5 5 5 5 5 5 5 5 5 5 3 5 5 5 5 5 5 5 1 5 1 5]]
Class: A2
Accuracy: 2/2

Class: B1
Accuracy: 0/1

Class: B2
Accuracy: 0/2

Class: C1
Accuracy: 0/2

Class: C2
Accuracy: 24/25

None
