In [1]:
import torch
from transformers import Trainer, GlueDataset, DataCollatorWithPadding, GlueDataTrainingArguments
from transformers import AlbertTokenizer, ConvbertForSequenceClassification, ConvbertModel
from torch.utils.data.dataloader import DataLoader
import os

# GLUE task sub-directory appended to each checkpoint directory below.
data_sub_dir = 'CoLA'

# Forward slashes only: in a normal (non-raw) string literal a backslash starts
# an escape sequence, so a path segment beginning with e.g. 't' or 'n' would be
# silently corrupted. Windows accepts '/' as a path separator.
model_dir = 'E:/ConvbertData/glue_models/convbert_12/' + data_sub_dir
albert_model_dir = 'E:/ConvbertData/glue_models/albert_ready/' + data_sub_dir

def get_last_checkpoint(dir_name):
    """Return the path of the highest-numbered checkpoint entry in ``dir_name``.

    An entry is considered when its name contains ``'checkpoint'`` and the
    text between its first and second ``'-'`` parses as an integer step
    (for example ``checkpoint-500``). Entries that mention ``checkpoint``
    but carry no parsable step are skipped instead of aborting the scan.

    Args:
        dir_name: Directory to scan.

    Returns:
        ``os.path.join(dir_name, name)`` for the entry with the largest step.

    Raises:
        FileNotFoundError: If no entry with a parsable step exists.
    """
    max_check = -1
    result = None
    for filename in os.listdir(dir_name):
        if 'checkpoint' not in filename:
            continue
        try:
            step = int(filename.split('-')[1])
        except (IndexError, ValueError):
            # e.g. 'checkpoint_notes.txt' or 'checkpoint-best': no numeric step.
            continue
        if step > max_check:
            max_check = step
            result = filename
    if result is None:
        # Fail with a clear message rather than letting os.path.join choke on None.
        raise FileNotFoundError(
            "no 'checkpoint-<step>' entry found in {!r}".format(dir_name)
        )
    return os.path.join(dir_name, result)

# Use the first CUDA device when torch reports one, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')

# Restore the ConvBERT classifier from the newest checkpoint under `model_dir`
# and place it on `device`.
model = ConvbertForSequenceClassification.from_pretrained(
    get_last_checkpoint(model_dir)
).to(device)

# Dataset arguments for the CoLA task; the data lives next to the other GLUE tasks.
data_args = GlueDataTrainingArguments(
    task_name='cola',
    data_dir='E:/ConvbertData/glue_data/' + data_sub_dir,
)

  from ._conv import register_converters as _register_converters


In [2]:
from transformers import AlbertForSequenceClassification, AutoModelLSTMForSequenceClassification,AutoModelForSequenceClassification

# Restore the comparison classifier from the newest checkpoint under
# `albert_model_dir` and place it on the same device as `model`.
albert_model = AutoModelForSequenceClassification.from_pretrained(
    get_last_checkpoint(albert_model_dir)
).to(device)

In [3]:
from transformers import default_data_collator

# Dev split of the task described by `data_args`, tokenized with `tokenizer`.
dataset = GlueDataset(data_args, mode="dev", tokenizer=tokenizer)

# One example per batch: the evaluation loop reads element [0] of each batch.
data = DataLoader(dataset, collate_fn=default_data_collator, batch_size=1)

In [4]:
def prepare_inputs(inputs):
    """Move every tensor value of ``inputs`` to the global ``device``.

    The dict is updated in place (non-tensor values are left as they are)
    and the very same dict object is returned.
    """
    tensor_keys = [key for key, value in inputs.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        inputs[key] = inputs[key].to(device)
    return inputs

In [5]:
from collections import defaultdict
count = 0    # number of evaluated examples
avg_len = 0  # running SUM of word counts; divided by `count` when reported

# Per-model histograms keyed by length bucket (word count // 10).
albert_fail_len = defaultdict(int)
albert_correct_len = defaultdict(int)
convbert_fail_len = defaultdict(int)
convbert_correct_len = defaultdict(int)

albert_correct_count = 0
albert_fail_count = 0
convbert_correct_count = 0
convbert_fail_count = 0

# Set to True to print every example on which exactly one of the two models
# agrees with the label.
PRINT_DISAGREEMENTS = False

for inputs in data:
    count += 1
    tokens = tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True)
    labels = inputs['labels'][0].item()

    with torch.no_grad():
        # prepare_inputs updates the dict in place, so one call serves both models.
        batch = prepare_inputs(inputs)
        # Element 1 of each model's output is arg-maxed into a class index
        # (presumably the logits, with the loss at element 0 because labels
        # are part of the batch -- confirm against the model classes).
        # .item() on both so the comparisons below are plain int == int;
        # despite their names, these hold the predicted class, not logits.
        conv_logits = model(**batch)[1].argmax().item()
        albert_logits = albert_model(**batch)[1].argmax().item()

    # Length is measured in space-separated words of the decoded text.
    ln = len(tokens.split(' '))
    avg_len += ln
    bucket = ln // 10

    if albert_logits == labels:
        albert_correct_len[bucket] += 1
        albert_correct_count += 1
    else:
        albert_fail_len[bucket] += 1
        albert_fail_count += 1

    if conv_logits == labels:
        convbert_correct_len[bucket] += 1
        convbert_correct_count += 1
    else:
        convbert_fail_len[bucket] += 1
        convbert_fail_count += 1

    if PRINT_DISAGREEMENTS and (albert_logits == labels) != (conv_logits == labels):
        print(tokens)
        print('label:{}, albert: {}, convbert: {}'.format(labels, albert_logits, conv_logits))
            
            
def print_hist(corrects, fails):
    """Print the failure rate of each length bucket, in ascending bucket order.

    Args:
        corrects: Mapping of bucket -> number of correct predictions.
        fails: Mapping of bucket -> number of wrong predictions.

    One line, ``fails / (corrects + fails)``, is printed for every bucket that
    holds at least one example in either mapping. Buckets without failures are
    reported as ``0.0`` rather than omitted, so the rows printed for two models
    evaluated on the same data line up. Neither mapping is modified.
    """
    for bucket in sorted(set(corrects) | set(fails)):
        # .get() instead of [] so a defaultdict argument does not grow new keys.
        n_fail = fails.get(bucket, 0)
        n_total = corrects.get(bucket, 0) + n_fail
        if n_total == 0:
            # Key present with a zero count on both sides: nothing to report.
            continue
        print(n_fail / n_total)
        

# Per-model failure rate by length bucket, followed by the mean word count
# per evaluated example.
for model_label, correct_hist, fail_hist in (
    ('albert', albert_correct_len, albert_fail_len),
    ('convbert', convbert_correct_len, convbert_fail_len),
):
    print(model_label)
    print_hist(correct_hist, fail_hist)

print('avg')
print(avg_len / count)

albert
0.18414322250639387
0.26359832635983266
0.2727272727272727
convbert
0.309462915601023
0.3891213389121339
0.36363636363636365
avg
7.920421860019175


In [6]:
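# CoLA -- recorded results kept as notes (commented out so that running this
# cell does not fail with a NameError on the bare word "Cola").
# Each list holds (bucket, count) pairs for one model.
# NOTE(review): the average below differs from the one printed by the previous
# cell, so these numbers appear to come from a different run or length
# measure -- confirm before comparing them.
# albert
# [(0, 3), (1, 15), (2, 37), (3, 50), (4, 37), (5, 26), (6, 17), (7, 11), (8, 7), (9, 2), (10, 3), (11, 2), (12, 2), (13, 1)]
# convbert
# [(0, 3), (1, 29), (2, 71), (3, 77), (4, 65), (5, 37), (6, 23), (7, 17), (8, 9), (9, 2), (10, 4), (11, 1), (12, 4), (13, 1)]
# avg
# 41.830297219558965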

NameError: name 'Cola' is not defined