In [1]:
import torch
from tqdm import tqdm
from torch.nn import CrossEntropyLoss
from datasets.load import load_from_disk
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, DataCollatorWithPadding
from torcheval.metrics import MulticlassAccuracy, MulticlassF1Score
from transformers import RobertaTokenizer,RobertaForSequenceClassification

In [2]:
class STDataset(Dataset):
    """Map-style dataset that yields tokenized examples for a data collator.

    Args:
        dataset: indexable collection of rows; each row exposes
            ``'input_ids'``, ``'attention_mask'`` and, when ``labels`` is not
            given, ``'labels'``.
        indices: optional sequence of row positions in ``dataset``. When given,
            this dataset exposes only those rows, in that order.
        labels: optional sequence of labels that overrides the row's
            ``'labels'`` field. It is indexed by the position requested from
            this dataset (NOTE(review): i.e. assumed aligned with ``indices``
            when both are given -- confirm with callers).

    Each item is a dict with keys ``'input_ids'``, ``'attention_mask'`` and
    ``'label'``.
    """

    def __init__(self, dataset, indices=None, labels=None) -> None:
        self.indices = indices
        self.data = dataset
        self.labels = labels

    def __getitem__(self, index):
        # Translate a subset position into a row of the underlying dataset so
        # that __getitem__ agrees with __len__ when ``indices`` is set.
        row_index = self.indices[index] if self.indices is not None else index
        # Fetch the row once instead of re-indexing the dataset per field.
        row = self.data[row_index]
        # ``is not None`` rather than ``!= None``: with array/tensor labels the
        # latter is an element-wise comparison with an ambiguous truth value.
        label = self.labels[index] if self.labels is not None else row['labels']
        return {'input_ids': row['input_ids'],
                'attention_mask': row['attention_mask'],
                'label': label}

    def __len__(self):
        # An empty ``indices`` is an empty subset, not "use the whole dataset".
        if self.indices is not None:
            return len(self.indices)
        return len(self.data)

In [3]:
# Number of classes predicted by the classification head; shared by the model
# and the class-wise metrics so they cannot drift apart.
NUM_LABELS = 10
EVAL_BATCH_SIZE = 768

test_split = load_from_disk('data/yahoo_answers/tokenized/test')
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# Pads each batch to its longest sequence at collation time.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
criterion = CrossEntropyLoss()
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = RobertaForSequenceClassification.from_pretrained(
    "roberta-base",
    num_labels=NUM_LABELS,
    cache_dir='.cache/model',
)
model.to(device)

test_data = STDataset(test_split)
test_dataloader = DataLoader(test_data, batch_size=EVAL_BATCH_SIZE,
                             collate_fn=data_collator)

# Overall accuracy plus per-class (average=None) F1 and accuracy.
acc_metric = MulticlassAccuracy(device=device)
cw_f1_metric = MulticlassF1Score(num_classes=NUM_LABELS, average=None, device=device)
cw_acc_metric = MulticlassAccuracy(num_classes=NUM_LABELS, average=None, device=device)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.dense.bias', 'roberta.pooler.dense.bias', 'roberta.pooler.dense.weight', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.weight', 'classi

In [4]:
def _eval(model, dataloader, progress_bar):
    """Evaluate ``model`` over every batch of ``dataloader``.

    Relies on the module-level ``device``, ``criterion`` and the torcheval
    metrics ``acc_metric``, ``cw_f1_metric`` and ``cw_acc_metric``.

    Args:
        model: classification model whose forward output exposes ``.logits``.
        dataloader: yields dict-like batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` keys (the collator renames
            the dataset's ``'label'`` field to ``'labels'``).
        progress_bar: tqdm bar, advanced by one per batch.

    Returns:
        ``(eval_loss, cw_f1_score, accuracy, cw_accuracy)``: the per-sample
        mean loss, the per-class F1 list, the overall accuracy and the
        per-class accuracy list.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    metrics = (acc_metric, cw_f1_metric, cw_acc_metric)
    # Clear state left behind by a previous call that was interrupted
    # (e.g. KeyboardInterrupt) before reaching the reset at the end.
    for metric in metrics:
        metric.reset()

    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            # Select tensors by name rather than by the batch dict's ordering.
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = criterion(logits, labels)

            # criterion returns the batch mean; weight it by batch size so a
            # shorter final batch does not skew the dataset-level average.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size

            for metric in metrics:
                metric.update(logits, labels)

            progress_bar.update(1)

    if n_samples == 0:
        raise ValueError('dataloader yielded no samples')

    eval_loss = total_loss / n_samples
    accuracy = acc_metric.compute().item()
    cw_f1_score = cw_f1_metric.compute().cpu().tolist()
    cw_accuracy = cw_acc_metric.compute().cpu().tolist()

    for metric in metrics:
        metric.reset()

    return eval_loss, cw_f1_score, accuracy, cw_accuracy

In [5]:
model_files = [
    '.cache/v3.0/data/yahoo_answers/eval_model.pt',
    '.cache/model/yahoo_answers_test_model_org.pt',
    '.cache/model/yahoo_answers_test_model_sch.pt',
    '.cache/model/yahoo_answers_test_model_sch2.pt',
]
for model_file in model_files:
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(model_file, map_location=device))
    # One bar per checkpoint; the context manager closes it even if the
    # evaluation is interrupted, instead of leaking all but the last bar.
    with tqdm(desc='Validation ', unit=' batch', colour='white',
              total=len(test_dataloader)) as eval_pbar:
        ev_loss, _, ev_acc, ev_cw_acc = _eval(model, test_dataloader, eval_pbar)
        eval_pbar.write('File: %s\nTest Loss: %.4f ; Test Acc: %.4f\n\n'
                        % (model_file, ev_loss, ev_acc * 100))
        eval_pbar.write('CW Acc: ' + str(ev_cw_acc))

Validation :   0%|[37m          [0m| 0/77 [00:00<?, ? batch/s]         

File: .cache/v3.0/data/yahoo_answers/eval_model.pt
Test Loss: 1.0112 ; Test Acc: 66.8877


CW Acc: [0.5116239786148071, 0.7444574236869812, 0.7701736688613892, 0.4702199697494507, 0.8294979333877563, 0.795530378818512, 0.4429609477519989, 0.6596574783325195, 0.7449652552604675, 0.7246254682540894]


Validation :   0%|[37m          [0m| 0/77 [00:00<?, ? batch/s]


KeyboardInterrupt: 