In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd

from utils import *

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
def get_acc(y_true, y_pred):
    total = y_true.size(0)
    correct = (y_pred == y_true).sum().item()
    return correct / total

def get_acc_at_k(y_true, y_pred, k=2):
    y_true = torch.tensor(y_true) if type(y_true) != torch.Tensor else y_true
    y_pred = torch.tensor(y_pred) if type(y_pred) != torch.Tensor else y_pred
    total = len(y_true)
    y_weights, y_idx = torch.topk(y_true, k=k, dim=-1)
    out_weights, out_idx = torch.topk(y_pred, k=k, dim=-1)
    correct = torch.sum(torch.eq(y_idx, out_idx) * y_weights)
    acc = correct / total
    return acc.item()

class CustomModelWrapper:
    def __init__(self, model, tokenizer, batch_size=4):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(self.model.parameters()).device
        self.batch_size = batch_size

    def __call__(self, text_input_list):
        out = []
        i = 0
        while i < len(text_input_list):
            batch = text_input_list[i : i + self.batch_size]
            encoding = self.tokenizer(batch, padding=True, truncation=True, max_length=250, return_tensors='pt')
            outputs = self.model(encoding['input_ids'].to(self.device), attention_mask=encoding['attention_mask'].to(self.device))
            # preds = torch.nn.functional.softmax(outputs.logits, dim=1).detach().cpu()
            out.append(outputs.logits.detach().cpu())
            i += self.batch_size
        out = torch.cat(out)
        return out

In [5]:
num_suites = 100
num_tests = 100

datasets = [('AG_NEWS', 'ag-news'), ('SST2', 'SST-2')]
tranforms = ['ORIG', 'INV', 'SIB', 'INVSIB', 'TextMix', 'SentMix', 'WordMix']
MODEL_NAMES = [
    "textattack/bert-base-uncased-SST-2",
    "textattack/roberta-base-SST-2",
    "textattack/bert-base-uncased-ag-news",
    "textattack/roberta-base-ag-news",
]

results = []
for MODEL_NAME in MODEL_NAMES:
    
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)

    mw = CustomModelWrapper(model, tokenizer)
    
    for d, d_ in datasets:
        
        if d_ not in MODEL_NAME:
            continue
        
        for t in tranforms:
            
            text = npy_load("./assets/" + d + "/" + t + "/text.npy")
            label = npy_load("./assets/" + d + "/" + t + "/label.npy")
            is_soft_label = False
            if len(label.shape) > 1:
                is_soft_label = True
                                
            # print("MODEL_NAME: {}, dataset: {}-{}, is_soft_label: {}".format(MODEL_NAME, d, t, is_soft_label))
            
            accs = []
            for i in range(num_suites):
                
                idx = np.random.choice(np.arange(len(text)), num_tests, replace=False)
                text_sample = text[idx]
                label_sample = label[idx]
                               
                logits = mw([str(x) for x in text_sample])
                y_true = torch.tensor(label_sample)
                
                if is_soft_label:
                    acc = get_acc_at_k(y_true, logits, k=2)
                else:
                    soft_m = torch.softmax(logits, dim=1)
                    y_pred = torch.argmax(soft_m, dim=1)
                    acc = get_acc(y_true, y_pred)
                    
                accs.append(acc)
                
            test_acc = sum(accs) / len(accs)
                
            out = {
                "MODEL_NAME": MODEL_NAME,
                "dataset": d + "-" + t,
                "test_acc": test_acc
            }
            
            print(out)
            results.append(out)
            
df = pd.DataFrame(results)

MODEL_NAME: textattack/bert-base-uncased-SST-2, dataset: SST2-ORIG, is_soft_label: False
{'MODEL_NAME': 'textattack/bert-base-uncased-SST-2', 'dataset': 'SST2-ORIG', 'test_acc': 0.9871999999999994}
MODEL_NAME: textattack/bert-base-uncased-SST-2, dataset: SST2-INV, is_soft_label: False
{'MODEL_NAME': 'textattack/bert-base-uncased-SST-2', 'dataset': 'SST2-INV', 'test_acc': 0.7443000000000005}
MODEL_NAME: textattack/bert-base-uncased-SST-2, dataset: SST2-SIB, is_soft_label: True
{'MODEL_NAME': 'textattack/bert-base-uncased-SST-2', 'dataset': 'SST2-SIB', 'test_acc': 0.7052740745272483}
MODEL_NAME: textattack/bert-base-uncased-SST-2, dataset: SST2-INVSIB, is_soft_label: True
{'MODEL_NAME': 'textattack/bert-base-uncased-SST-2', 'dataset': 'SST2-INVSIB', 'test_acc': 0.7184545924036227}
MODEL_NAME: textattack/bert-base-uncased-SST-2, dataset: SST2-TextMix, is_soft_label: True
{'MODEL_NAME': 'textattack/bert-base-uncased-SST-2', 'dataset': 'SST2-TextMix', 'test_acc': 0.8232}
MODEL_NAME: textatt

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=525.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898822.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=150.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=25.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=501003010.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at textattack/roberta-base-SST-2 were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


MODEL_NAME: textattack/roberta-base-SST-2, dataset: SST2-ORIG, is_soft_label: False
{'MODEL_NAME': 'textattack/roberta-base-SST-2', 'dataset': 'SST2-ORIG', 'test_acc': 0.9625999999999998}
MODEL_NAME: textattack/roberta-base-SST-2, dataset: SST2-INV, is_soft_label: False
{'MODEL_NAME': 'textattack/roberta-base-SST-2', 'dataset': 'SST2-INV', 'test_acc': 0.7603}
MODEL_NAME: textattack/roberta-base-SST-2, dataset: SST2-SIB, is_soft_label: True
{'MODEL_NAME': 'textattack/roberta-base-SST-2', 'dataset': 'SST2-SIB', 'test_acc': 0.6835538111180094}
MODEL_NAME: textattack/roberta-base-SST-2, dataset: SST2-INVSIB, is_soft_label: True
{'MODEL_NAME': 'textattack/roberta-base-SST-2', 'dataset': 'SST2-INVSIB', 'test_acc': 0.7172590575584402}
MODEL_NAME: textattack/roberta-base-SST-2, dataset: SST2-TextMix, is_soft_label: True
{'MODEL_NAME': 'textattack/roberta-base-SST-2', 'dataset': 'SST2-TextMix', 'test_acc': 0.8348000000000002}
MODEL_NAME: textattack/roberta-base-SST-2, dataset: SST2-SentMix, is_

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=754.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=798293.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456356.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=239.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=25.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=501009162.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at textattack/roberta-base-ag-news were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


MODEL_NAME: textattack/roberta-base-ag-news, dataset: AG_NEWS-ORIG, is_soft_label: False
{'MODEL_NAME': 'textattack/roberta-base-ag-news', 'dataset': 'AG_NEWS-ORIG', 'test_acc': 0.985799999999999}
MODEL_NAME: textattack/roberta-base-ag-news, dataset: AG_NEWS-INV, is_soft_label: False
{'MODEL_NAME': 'textattack/roberta-base-ag-news', 'dataset': 'AG_NEWS-INV', 'test_acc': 0.8363999999999998}
MODEL_NAME: textattack/roberta-base-ag-news, dataset: AG_NEWS-SIB, is_soft_label: True
{'MODEL_NAME': 'textattack/roberta-base-ag-news', 'dataset': 'AG_NEWS-SIB', 'test_acc': 0.6100767345357251}
MODEL_NAME: textattack/roberta-base-ag-news, dataset: AG_NEWS-INVSIB, is_soft_label: True
{'MODEL_NAME': 'textattack/roberta-base-ag-news', 'dataset': 'AG_NEWS-INVSIB', 'test_acc': 0.6975635848964127}
MODEL_NAME: textattack/roberta-base-ag-news, dataset: AG_NEWS-TextMix, is_soft_label: True
{'MODEL_NAME': 'textattack/roberta-base-ag-news', 'dataset': 'AG_NEWS-TextMix', 'test_acc': 0.6133232601808546}
MODEL_NA

In [6]:
df

Unnamed: 0,MODEL_NAME,dataset,test_acc
0,textattack/bert-base-uncased-SST-2,SST2-ORIG,0.9872
1,textattack/bert-base-uncased-SST-2,SST2-INV,0.7443
2,textattack/bert-base-uncased-SST-2,SST2-SIB,0.705274
3,textattack/bert-base-uncased-SST-2,SST2-INVSIB,0.718455
4,textattack/bert-base-uncased-SST-2,SST2-TextMix,0.8232
5,textattack/bert-base-uncased-SST-2,SST2-SentMix,0.8134
6,textattack/bert-base-uncased-SST-2,SST2-WordMix,0.7412
7,textattack/roberta-base-SST-2,SST2-ORIG,0.9626
8,textattack/roberta-base-SST-2,SST2-INV,0.7603
9,textattack/roberta-base-SST-2,SST2-SIB,0.683554


In [7]:
df.to_csv('test_models_results.csv')