In [82]:
%cd '/app'

/app


In [83]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [84]:
import torch
import yaml 
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

import argparse
import re 
from src_dev.loader.checkpoint import load_trained_bag
from src_dev.loader.models.bart_for_polynomial_system import BartForPolynomialSystemGeneration

In [85]:
from src_dev.data.tokenizers import set_tokenizer, set_vocab
from src_dev.loader.data import DataCollator, load_data
# Execute the Sage helper script in this session.
# NOTE(review): `prefix_to_poly` / `poly_to_prefix` used in later cells are not
# imported anywhere visible — presumably they are defined by this script; confirm.
load('src_dev/data/symbolic_utils.sage')

In [86]:
# Experiment configuration: dataset identifier plus the directories derived from it.
dataset_name = 'gb_dataset_n=2_field=RR'
# Directory passed to `load_trained_bag` below (run with regression_weights=0.01).
save_dir = f'/app/results2/shape_gb_lex/regression_weights=0.01/{dataset_name}-C'
# save_dir = f'/app/results2/shape_gb_lex/{dataset_name}'
# Path prefix handed to `load_data`; the loader log shows it reads `<data_dir>.test`.
data_dir = f'/app/data_dev/{dataset_name}/data'

In [87]:
# Restore the trained bundle from `save_dir` and pull out the three entries
# the rest of the notebook works with.
bag = load_trained_bag(save_dir, model_name='bart+', from_checkpoint=False)
model, tokenizer, params = (bag[key] for key in ('model', 'tokenizer', 'params'))

The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.


In [88]:
# Test split loaded twice: once as an indexable dataset (used later for
# `test_dataset[i]['target']`) and once as a tokenized, batched loader.
test_dataset = load_data(
    data_dir,
    encoding='lex.prefix',
    extensions=['test'],
    do_shuffle=[False],
    return_dataloader=False,
)

test_loader = load_data(
    data_dir,
    encoding='lex.prefix',
    extensions=['test'],
    do_shuffle=[False],
    return_dataloader=True,
    tokenizer=tokenizer,
    batch_sizes=[int(20)],
    continuous_coefficient=True,
    continuous_exponent=False,
)

loading ... /app/data_dev/gb_dataset_n=2_field=RR/data.test
loading ... /app/data_dev/gb_dataset_n=2_field=RR/data.test
content of batch_size: 20


In [89]:
# Grab the first test batch and move every entry onto the GPU in place.
batch = next(iter(test_loader))
for key in list(batch):
    batch[key] = batch[key].cuda()

labels, continuous_labels = batch['labels'], batch['continuous_labels']

In [90]:
# Two-variable polynomial ring over RR (variables x0, x1) with lexicographic
# term order; used to parse and re-serialize the prefix-encoded polynomials.
ring = PolynomialRing(RR, 'x', 2, order='lex')

In [171]:
# Baseline generation on the unmodified batch, followed by a second generation
# in which every continuous input label is multiplied by random Gaussian noise.
# qfn = lambda x: x.round().clamp(0, 7)
qfn = None
prediction = model.generate(
    input_ids=batch['input_ids'], 
    input_continuous_labels=batch['input_continuous_labels'],
    attention_mask=batch['attention_mask'], 
    continuous_token_ids = [tokenizer.vocab['[C]']],
    max_length = 200,
    quantize_fn = qfn,
    )

# scale = int(2.5)
# Seed the RNG right before drawing the perturbation so that re-running this
# cell reproduces the same `scale` (and therefore the same `prediction2`).
torch.manual_seed(int(0))
scale = torch.randn_like(batch['input_continuous_labels'])
# NOTE(review): this call caps the length at the target length whereas the
# baseline call above uses 200 — confirm the asymmetry is intended.
prediction2 = model.generate(
    input_ids=batch['input_ids'], 
    input_continuous_labels=batch['input_continuous_labels']*scale,
    attention_mask=batch['attention_mask'], 
    continuous_token_ids = [tokenizer.vocab['[C]']],
    max_length = batch['decoder_input_ids'].shape[-1],
    quantize_fn = qfn,
    )

In [172]:
# Convert the raw generation output into text and show the first five systems.
# NOTE(review): `prediction` is rebound to its own post-processed form, so this
# cell is not idempotent — re-run the generate cell before re-running this one.
prediction = model.postprocess_prediction(prediction, tokenizer, skip_special_tokens=True)
prediction['prediction_texts'][:5]

['+ * C1.1693471670150757 x1 x0 [SEP] ^ x1 E3',
 '+ * C-3.700286626815796 x1 + * C-2.3708648681640625 ^ x1 E2 + * C3.2721028327941895 ^ x1 E4 x0 [SEP] + C0.6829630136489868 + * C2.9590835571289062 ^ x1 E3 ^ x1 E5',
 '+ C-4.518379211425781 + * C1.31858229637146 ^ x1 E2 + * C4.336902618408203 ^ x1 E3 + * C-2.4424257278442383 ^ x1 E4 x0 [SEP] + C-2.1039786338806152 + * C-3.958937644958496 ^ x1 E2 + * C3.0583293437957764 ^ x1 E4 ^ x1 E5',
 '+ * C-2.7408194541931152 x1 + * C-2.9938340187072754 ^ x1 E2 + * C-0.058328039944171906 ^ x1 E3 + * C-4.4817047119140625 ^ x1 E4 x0 [SEP] + C-2.460970401763916 + * C-4.70438289642334 x1 + * C4.439737796783447 ^ x1 E4 ^ x1 E5',
 '+ C-0.01693810522556305 + * C-2.703731060028076 x1 + * C0.983980119228363 ^ x1 E2 + * C-2.5385243892669678 ^ x1 E3 + * C4.888523578643799 ^ x1 E4 x0 [SEP] + C4.533261775970459 + * C-2.1941351890563965 x1 + * C0.1081736832857132 ^ x1 E3 + * C3.6623857021331787 ^ x1 E4 ^ x1 E5']

In [173]:
# Same post-processing for the generation obtained from the perturbed inputs.
# NOTE(review): rebinding `prediction2` to its own post-processed form makes this
# cell non-idempotent — re-run the generate cell before re-running this one.
prediction2 = model.postprocess_prediction(prediction2, tokenizer, skip_special_tokens=True)
prediction2['prediction_texts'][:5]

['+ * C-0.4976256787776947 x1 x0 [SEP] ^ x1 E3',
 '+ * C-2.1815810203552246 x1 + * C-0.255695641040802 ^ x1 E2 + * C1.187922477722168 ^ x1 E4 x0 [SEP] + C-0.2968845069408417 + * C3.875363826751709 ^ x1 E3 ^ x1 E5',
 '+ C-4.136988639831543 + * C0.3561578392982483 ^ x1 E2 + * C1.1284410953521729 ^ x1 E3 + * C2.882077217102051 ^ x1 E4 x0 [SEP] + C-0.4171329736709595 + * C0.10180635750293732 x1 + * C3.5139763355255127 ^ x1 E2 ^ x1 E5',
 '+ * C2.6290760040283203 x1 + * C2.8038206100463867 ^ x1 E2 + * C-0.3618961572647095 ^ x1 E3 + * C3.9533467292785645 ^ x1 E4 x0 [SEP] + C0.2772074341773987 + * C-0.5681028366088867 x1 + * C-4.241492748260498 ^ x1 E4 ^ x1 E5',
 '+ C-0.040582045912742615 + * C3.2557995319366455 x1 + * C1.0447475910186768 ^ x1 E2 + * C-0.45094817876815796 ^ x1 E3 + * C3.767113208770752 ^ x1 E4 x0 [SEP] + C3.366997241973877 + * C2.3078360557556152 x1 + * C0.7036225199699402 ^ x1 E3 + * C-1.7481807470321655 ^ x1 E4 ^ x1 E5']

In [204]:


def scale_polynomial_system(model, tokenizer, data_coallator, batch, scales, ring):
    """Build a new batch whose input polynomials are multiplied by constants.

    Every input system in ``batch`` is decoded to prefix text, parsed into
    polynomials of ``ring``, multiplied polynomial-by-polynomial with
    ``scales`` and serialized back to prefix text. The decoded target texts
    are passed through unchanged.

    Args:
        model: object exposing ``decode_continuous_ids``.
        tokenizer: tokenizer forwarded to ``decode_continuous_ids``.
        data_coallator: callable that receives a list of
            ``{'input': str, 'target': str}`` dicts; its return value is
            returned as-is.
        batch: mapping providing ``'input_ids'``, ``'input_continuous_labels'``,
            ``'decoder_input_ids'`` and ``'target_continuous_labels'``.
        scales: sequence of multipliers. A length-1 sequence is applied to
            every polynomial of every system; otherwise ``scales[j]`` multiplies
            the j-th polynomial and surplus entries are ignored.
        ring: polynomial ring handed to ``prefix_to_poly`` / ``poly_to_prefix``.

    Returns:
        Whatever ``data_coallator`` returns for the rescaled examples.

    Raises:
        ValueError: if ``scales`` is empty, or has more than one entry but
            fewer entries than some system has polynomials.
    """
    scales = list(scales)
    if not scales:
        raise ValueError('scales must contain at least one multiplier')
    # A single multiplier is broadcast per system (sized to that system) instead
    # of to a fixed length, so no polynomial can be dropped by zip().
    broadcast = len(scales) == 1

    inputs_text = model.decode_continuous_ids(batch['input_ids'], 
                                          batch['input_continuous_labels'], 
                                          tokenizer, 
                                          skip_special_tokens=True, 
                                          quantized=False,
                                          continuous_vocab_size=1)
    
    targets_text = model.decode_continuous_ids(batch['decoder_input_ids'], 
                                          batch['target_continuous_labels'], 
                                          tokenizer, 
                                          skip_special_tokens=True, 
                                          quantized=False,
                                          continuous_vocab_size=1)

    batch_new = []
    for intext, target_text in zip(inputs_text, targets_text):
        F = [prefix_to_poly(t, ring) for t in intext.split(' [SEP] ')]
        system_scales = scales * len(F) if broadcast else scales
        if len(system_scales) < len(F):
            # zip() would silently truncate the system; fail loudly instead.
            raise ValueError(
                f'got {len(system_scales)} scales for a system of {len(F)} polynomials'
            )
        Fnew_text = ' [SEP] '.join(
            poly_to_prefix(f * s, ring) for f, s in zip(F, system_scales)
        )
        batch_new.append({'input': Fnew_text, 'target': target_text})

    return data_coallator(batch_new)

    

In [226]:
# Rescale the input polynomials of `batch` and re-collate them into a new batch.
# scales = [1.1, 0.9, 0.9, 1.1]
# A single-entry list is broadcast by `scale_polynomial_system`.
scales = [2.0]
# Collator configured like the test loader (continuous coefficients, discrete exponents).
data_coallator = DataCollator(tokenizer, continuous_coefficient=True, continuous_exponent=False, keep_originals=False)
batch_new = scale_polynomial_system(model, tokenizer, data_coallator, batch, scales, ring)

In [227]:
# Move the rescaled batch to the GPU, generate from it, and decode to text.
for key in list(batch_new):
    batch_new[key] = batch_new[key].cuda()

raw_scaled = model.generate(
    input_ids=batch_new['input_ids'],
    input_continuous_labels=batch_new['input_continuous_labels'],
    attention_mask=batch_new['attention_mask'],
    continuous_token_ids=[tokenizer.vocab['[C]']],
    max_length=batch_new['decoder_input_ids'].shape[-1],
    quantize_fn=qfn,
)

prediction_scaled = model.postprocess_prediction(raw_scaled, tokenizer, skip_special_tokens=True)
prediction_scaled['prediction_texts'][:5]

['+ * C2.2628941535949707 x1 x0 [SEP] ^ x1 E3',
 '+ * C-4.761510848999023 x1 + * C-3.603911876678467 ^ x1 E2 + * C4.1028594970703125 ^ x1 E4 x0 [SEP] + C1.3850313425064087 + * C4.970139026641846 ^ x1 E3 ^ x1 E5',
 '+ C-5.0821990966796875 + * C2.449305772781372 ^ x1 E2 + * C4.91221284866333 ^ x1 E3 + * C-4.489055156707764 ^ x1 E4 x0 [SEP] + C-4.142938137054443 + * C-3.1981992721557617 ^ x1 E2 + * C1.8778154850006104 ^ x1 E3 ^ x1 E5',
 '+ * C-4.42269229888916 x1 + * C-4.188697338104248 ^ x1 E2 + * C0.31319519877433777 ^ x1 E3 + * C-4.710109710693359 ^ x1 E4 x0 [SEP] + C-5.078081130981445 + * C-5.012518405914307 x1 + * C4.701960563659668 ^ x1 E4 ^ x1 E5',
 '+ C0.0007309690117835999 + * C-4.004970550537109 x1 + * C1.6103410720825195 ^ x1 E2 + * C-3.280632734298706 ^ x1 E3 + * C4.948937892913818 ^ x1 E4 x0 [SEP] + C5.113838195800781 + * C-4.147992134094238 x1 + * C0.3020535707473755 ^ x1 E3 + * C4.61102294921875 ^ x1 E4 ^ x1 E5']

In [228]:
# Re-display the baseline predictions for side-by-side comparison with the
# predictions obtained from the rescaled inputs.
prediction['prediction_texts'][:5]

['+ * C1.1693471670150757 x1 x0 [SEP] ^ x1 E3',
 '+ * C-3.700286626815796 x1 + * C-2.3708648681640625 ^ x1 E2 + * C3.2721028327941895 ^ x1 E4 x0 [SEP] + C0.6829630136489868 + * C2.9590835571289062 ^ x1 E3 ^ x1 E5',
 '+ C-4.518379211425781 + * C1.31858229637146 ^ x1 E2 + * C4.336902618408203 ^ x1 E3 + * C-2.4424257278442383 ^ x1 E4 x0 [SEP] + C-2.1039786338806152 + * C-3.958937644958496 ^ x1 E2 + * C3.0583293437957764 ^ x1 E4 ^ x1 E5',
 '+ * C-2.7408194541931152 x1 + * C-2.9938340187072754 ^ x1 E2 + * C-0.058328039944171906 ^ x1 E3 + * C-4.4817047119140625 ^ x1 E4 x0 [SEP] + C-2.460970401763916 + * C-4.70438289642334 x1 + * C4.439737796783447 ^ x1 E4 ^ x1 E5',
 '+ C-0.01693810522556305 + * C-2.703731060028076 x1 + * C0.983980119228363 ^ x1 E2 + * C-2.5385243892669678 ^ x1 E3 + * C4.888523578643799 ^ x1 E4 x0 [SEP] + C4.533261775970459 + * C-2.1941351890563965 x1 + * C0.1081736832857132 ^ x1 E3 + * C3.6623857021331787 ^ x1 E4 ^ x1 E5']

In [144]:
# Decode the batch's input token ids together with their continuous labels
# back into prefix-notation text (one string per sample; the cell below
# indexes it per sample and splits on ' [SEP] ').
inputs_text = model.decode_continuous_ids(batch['input_ids'], 
                                          batch['input_continuous_labels'], 
                                          tokenizer, 
                                          skip_special_tokens=True, 
                                          quantized=False,
                                          continuous_vocab_size=1)

In [146]:
# For the first five samples: parse the input system and both predictions,
# print the two predicted systems, whether their first polynomials share the
# same monomials, and the coefficient ratios (leading coefficient skipped).
to_system = lambda text: [prefix_to_poly(t, ring) for t in text.split(' [SEP] ')]

Fs = []
for i in range(5):
    F = to_system(inputs_text[i])
    Fs.append(F)

    G1 = to_system(prediction['prediction_texts'][i])
    G2 = to_system(prediction2['prediction_texts'][i])

    print(G1)
    print(G2)

    same_support = G1[0].monomials() == G2[0].monomials()
    print(same_support)

    ratios = []
    for c1, c2 in zip(G1[0].coefficients()[1:], G2[0].coefficients()[1:]):
        ratios.append(c1 / c2)
    print(ratios)
    print()

[x0 + 1.16934716701508*x1, x1^3]
[x0 - 0.110448531806469*x1, x1^3]
True
[-10.5872585890416]

[x0 + 3.27210283279419*x1^4 - 2.37086486816406*x1^2 - 3.70028662681580*x1, x1^5 + 2.95908355712891*x1^3 + 0.682963013648987]
[x0 + 0.385334432125092*x1^4 - 3.36606168746948*x1^2 + 3.75350165367126*x1, x1^5 - 0.727133989334106*x1^3 + 0.184677556157112]
True
[8.49159213400365, 0.704343855904322, -0.985822564696776]

[x0 - 2.44242572784424*x1^4 + 4.33690261840820*x1^3 + 1.31858229637146*x1^2 - 4.51837921142578, x1^5 + 3.05832934379578*x1^4 - 3.95893764495850*x1^2 - 2.10397863388062]
[x0 - 2.06073474884033*x1^4 + 2.76142883300781*x1^3 - 0.114703021943569*x1^2 - 2.21860432624817, x1^5 - 2.05031394958496*x1^4 - 4.49889183044434*x1^2 + 0.136174470186234]
True
[1.18522081952503, 1.57052847662358, -11.4956195053010, 2.03658631598664]

[x0 - 4.48170471191406*x1^4 - 0.0583280399441719*x1^3 - 2.99383401870728*x1^2 - 2.74081945419312*x1, x1^5 + 4.43973779678345*x1^4 - 4.70438289642334*x1 - 2.46097040176392]

In [149]:
# Inspect the parsed input system of the second sample.
Fs[1]

[0.607995390892029*x0^3 + 1.86179280281067*x0^2*x1^4 - 1.36398077011108*x0^2*x1^2 - 2.13551139831543*x0^2*x1 - 0.947197914123535*x0*x1^5 - 2.56432056427002*x0*x1^3 - 0.634114444255829*x0,
 0.164125367999077*x0^4*x1^2 + 0.502581775188446*x0^3*x1^6 - 0.368199884891510*x0^3*x1^4 - 0.576470792293549*x0^3*x1^3 - 0.255691409111023*x0^2*x1^7 - 0.692225694656372*x0^2*x1^5 - 0.171176061034203*x0^2*x1^2 + x0 + 3.06218242645264*x1^4 - 2.24340629577637*x1^2 - 3.51238083839417*x1,
 -0.315730839967728*x0^3*x1^3 - 0.966825485229492*x0^2*x1^7 + 0.708312630653381*x0^2*x1^5 + 1.10896706581116*x0^2*x1^4 - 0.308688610792160*x0^2*x1^2 + 0.491878092288971*x0*x1^8 + 0.386385977268219*x0*x1^6 + 0.692514002323151*x0*x1^4 + 1.41352629661560*x0*x1^3 + 0.796164810657501*x0 + x1^5 + 2.43800187110901*x1^4 + 2.70726990699768*x1^3 - 1.78612112998962*x1^2 - 2.79643392562866*x1 + 0.669463515281677]

In [154]:
# First polynomial of the second sample's input system.
f1 = Fs[1][0]

In [163]:
# Keep the monomials of f1 but replace every coefficient with a random real.
coeffs_before = f1.dict()
print(coeffs_before)
newd = {monomial: RR.random_element() for monomial in coeffs_before}
print(newd)

{(2, 1): -2.13551139831543, (2, 2): -1.36398077011108, (2, 4): 1.86179280281067, (3, 0): 0.607995390892029, (1, 0): -0.634114444255829, (1, 3): -2.56432056427002, (1, 5): -0.947197914123535}
{(2, 1): -0.204062945508356, (2, 2): 0.576721479931021, (2, 4): -0.292769495854436, (3, 0): -0.0207964777007628, (1, 0): 0.181337454505818, (1, 3): -0.392938865234981, (1, 5): 0.160672140236357}


In [165]:
# Same system as Fs[1], with its first polynomial replaced by the
# random-coefficient polynomial built from `newd`.
Fnew = [ring(newd), *Fs[1][1:]]

In [170]:
# Groebner basis of the modified system. The recorded output is
# [1.00000000000000], i.e. the perturbed generators span the whole ring.
G = ideal(Fnew).groebner_basis()
G

[1.00000000000000]

In [168]:
# Serialize each polynomial of the modified system back to prefix notation.
Fnew_prefix = [poly_to_prefix(f, ring) for f in Fnew]

In [169]:
# Display the prefix strings of the modified system.
Fnew_prefix

['+ * C-0.204062945508356 * ^ x0 E2 x1 + * C0.576721479931021 * ^ x0 E2 ^ x1 E2 + * C-0.292769495854436 * ^ x0 E2 ^ x1 E4 + * C-0.0207964777007628 ^ x0 E3 + * C0.181337454505818 x0 + * C-0.392938865234981 * x0 ^ x1 E3 * C0.160672140236357 * x0 ^ x1 E5',
 '+ * C-3.51238083839417 x1 + * C-2.24340629577637 ^ x1 E2 + * C3.06218242645264 ^ x1 E4 + x0 + * C-0.576470792293549 * ^ x0 E3 ^ x1 E3 + * C-0.368199884891510 * ^ x0 E3 ^ x1 E4 + * C0.502581775188446 * ^ x0 E3 ^ x1 E6 + * C0.164125367999077 * ^ x0 E4 ^ x1 E2 + * C-0.171176061034203 * ^ x0 E2 ^ x1 E2 + * C-0.692225694656372 * ^ x0 E2 ^ x1 E5 * C-0.255691409111023 * ^ x0 E2 ^ x1 E7',
 '+ * C-2.79643392562866 x1 + * C-1.78612112998962 ^ x1 E2 + * C2.43800187110901 ^ x1 E4 + * C0.796164810657501 x0 + * C1.41352629661560 * x0 ^ x1 E3 + * C0.692514002323151 * x0 ^ x1 E4 + * C0.386385977268219 * x0 ^ x1 E6 + * C-0.308688610792160 * ^ x0 E2 ^ x1 E2 + * C1.10896706581116 * ^ x0 E2 ^ x1 E4 + * C0.708312630653381 * ^ x0 E2 ^ x1 E5 + * C-0.9668254

In [33]:
# NOTE(review): this cell repeats the `prediction2` post-processing cell that
# appears earlier in the notebook; its execution count (33) is older and its
# recorded output differs. Re-running it after that earlier cell would feed
# already post-processed output back into `postprocess_prediction`.
prediction2 = model.postprocess_prediction(prediction2, tokenizer, skip_special_tokens=True)
prediction2['prediction_texts'][:5]

['+ * C-2.741446018218994 x1 x0 [SEP] ^ x1 E3',
 '+ * C-4.753111839294434 x1 + * C-4.727725505828857 ^ x1 E2 + * C-4.273586273193359 ^ x1 E4 x0 [SEP] + C-3.8450419902801514 + * C-4.1074113845825195 ^ x1 E3 ^ x1 E5',
 '+ * C-3.465665102005005 ^ x1 E2 + * C-3.530829668045044 ^ x1 E3 + * C-3.8452279567718506 ^ x1 E4 x0 [SEP] + * C-4.087672710418701 ^ x1 E3 + * C-2.7084951400756836 ^ x1 E4 ^ x1 E5',
 '+ * C-2.912614345550537 x1 + * C-3.794856309890747 ^ x1 E2 + * C-4.114710330963135 ^ x1 E3 + * C-4.533664703369141 ^ x1 E4 x0 [SEP] + * C-3.885742664337158 x1 + * C-3.72438383102417 ^ x1 E3 + * C-4.404922962188721 ^ x1 E4 ^ x1 E5',
 '+ C1.442967414855957 + * C-1.5009735822677612 x1 + * C-1.3417309522628784 ^ x1 E2 + * C-2.045741319656372 ^ x1 E3 + * C-2.121232748031616 ^ x1 E4 x0 [SEP] + * C-3.941521644592285 x1 + * C-3.4928479194641113 ^ x1 E2 + * C-3.551729679107666 ^ x1 E3 + * C-3.571502923965454 ^ x1 E4 ^ x1 E5']

In [22]:
# Exact-string-match accuracy of the predictions against the ground-truth
# targets; every mismatching pair is printed (prediction first, target second).
# Iterate over the predictions that actually exist rather than a hard-coded
# count: the batch holds 20 samples, so indexing up to 25 would run past the end.
hits = []
for i, a in enumerate(prediction['prediction_texts']):
    b = test_dataset[i]['target']
    hits.append(int(a == b))
    if a != b:
        print(f'[{a == b}]')
        print(f'  {a}')
        print(f'  {b}')

# Guard the empty case so the summary line never divides by zero.
if hits:
    print(f'acc: {sum(hits) / len(hits)}')
else:
    print('acc: n/a (no predictions)')

[False]
  + * C1.1693470478057861 x1 x0 [SEP] ^ x1 E3
  + * C1.07691876689381 x1 x0 [SEP] ^ x1 E3
[False]
  + * C-3.700287342071533 x1 + * C-2.3708648681640625 ^ x1 E2 + * C3.2721035480499268 ^ x1 E4 x0 [SEP] + C0.6829628348350525 + * C2.9590845108032227 ^ x1 E3 ^ x1 E5
  + * C-3.51238087903323 x1 + * C-2.24340640659216 ^ x1 E2 + * C3.06218241582541 ^ x1 E4 x0 [SEP] + C0.669463495636417 + * C2.70727001592005 ^ x1 E3 ^ x1 E5
[False]
  + C-4.518379211425781 + * C1.31858229637146 ^ x1 E2 + * C4.3369011878967285 ^ x1 E3 + * C-2.4424257278442383 ^ x1 E4 x0 [SEP] + C-2.103977680206299 + * C-3.958937883377075 ^ x1 E2 + * C3.0583298206329346 ^ x1 E4 ^ x1 E5
  + C-4.29603514557275 + * C1.23272103128843 ^ x1 E2 + * C4.13504616530492 ^ x1 E3 + * C-2.29396581090454 ^ x1 E4 x0 [SEP] + C-4.56662269794201 + * C-4.60353974968510 ^ x1 E2 + * C4.18503092803670 ^ x1 E4 ^ x1 E5
[False]
  + * C-2.7408199310302734 x1 + * C-2.9938337802886963 ^ x1 E2 + * C-0.05832803249359131 ^ x1 E3 + * C-4.481706142425537 

In [62]:
def eval_prediction(model, dataloader, tokenizer, use_tqdm=False, steps=None):
    """Measure exact-text-match accuracy of ``model.generate`` over a dataloader.

    For each batch the model generates with continuous outputs quantized by
    rounding and clamping to [0, 7]; the decoded prediction texts are compared
    for string equality with the decoded ``decoder_input_ids``.

    Args:
        model: object exposing ``generate`` and ``decode_prediction``.
        dataloader: iterable of batches (mappings whose values have ``.cuda()``)
            containing ``'input_ids'``, ``'decoder_input_ids'``,
            ``'input_continuous_labels'`` and ``'attention_mask'``.
        tokenizer: provides ``vocab['[C]']`` and ``batch_decode``.
        use_tqdm: wrap the dataloader in ``tqdm`` for a progress bar.
        steps: maximum number of batches to evaluate; ``None`` evaluates all.

    Returns:
        dict with ``'acc'`` (fraction of exact matches, 0.0 when nothing was
        evaluated), ``'hits'`` (one list of booleans per batch),
        ``'num_samples'`` and ``'runtime'`` (seconds).
    """
    # Local imports keep the function self-contained; it no longer depends on
    # a later cell having imported `time` into the notebook namespace.
    from itertools import islice
    from time import time

    # Loop-invariant pieces, built once instead of once per batch.
    quantize_fn = lambda t: t.round().clamp(0, 7)
    continuous_token_ids = [tokenizer.vocab['[C]']]

    hits_list = []
    num_hits, tot = 0, 0
    start_t = time()
    with torch.no_grad():
        # NOTE(review): `tqdm` is not imported in any visible cell; it must be
        # in scope before calling with use_tqdm=True.
        iterator = tqdm(dataloader) if use_tqdm else dataloader
        # islice stops after exactly `steps` batches (all batches when None);
        # the previous `i > steps` check let steps + 2 batches through.
        for batch in islice(iterator, steps):
            for k in batch: batch[k] = batch[k].cuda()
            targets = batch['decoder_input_ids']

            output = model.generate(input_ids               = batch['input_ids'], 
                                    input_continuous_labels = batch['input_continuous_labels'],
                                    attention_mask          = batch['attention_mask'], 
                                    continuous_token_ids    = continuous_token_ids,
                                    max_length              = targets.shape[-1],
                                    quantize_fn             = quantize_fn,)

            # NOTE(review): other cells call `model.postprocess_prediction` for
            # this step — confirm `decode_prediction` is still the right name.
            pred_texts = model.decode_prediction(output, tokenizer, skip_special_tokens=True)['prediction_texts']
            target_texts = tokenizer.batch_decode(targets, skip_special_tokens=True)

            hits = [pred == target for pred, target in zip(pred_texts, target_texts)]

            hits_list.append(hits)
            num_hits += sum(hits)
            tot += len(hits)

    # An empty dataloader (or steps=0) must not raise ZeroDivisionError.
    acc = num_hits / tot if tot else 0.0
    runtime = time() - start_t
    return {'acc': acc, 'hits': hits_list, 'num_samples': tot, 'runtime': runtime}



In [55]:
# `time` is imported here, just before the call that times the evaluation run.
from time import time 
# Evaluate on the whole test loader; the returned dict (including the full
# per-sample hit lists) is displayed as the cell output.
eval_prediction(model, test_loader, tokenizer, use_tqdm=False)

['+ x0 * C4 x1 [SEP] ^ x1 E3', '+ x0 + * C6 ^ x1 E4 * C3 x1 [SEP] + ^ x1 E5 + ^ x1 E3 C4', '+ x0 + * C5 ^ x1 E4 + ^ x1 E3 + * C2 ^ x1 E2 C4 [SEP] + ^ x1 E5 + ^ x1 E4 + ^ x1 E3 + * C4 x1 C4', '+ x0 + * C6 ^ x1 E4 + * C2 ^ x1 E3 + * C3 ^ x1 E2 * C5 x1 [SEP] + ^ x1 E5 + * C5 ^ x1 E4 C3', '+ x0 + * C4 ^ x1 E4 + * C2 ^ x1 E3 + * C5 x1 C1 [SEP] + ^ x1 E5 + * C4 ^ x1 E4 + ^ x1 E3 + * C2 x1 C2', '+ x0 * C2 ^ x1 E2 [SEP] ^ x1 E4', '+ x0 + ^ x1 E4 + * C5 ^ x1 E3 + * C3 ^ x1 E2 * C3 x1 [SEP] + ^ x1 E5 + * C5 ^ x1 E4 + * C5 ^ x1 E3 C2', '+ x0 + ^ x1 E4 + * C4 ^ x1 E3 + * C3 ^ x1 E2 C2 [SEP] + ^ x1 E5 + * C5 ^ x1 E4 + * C4 x1 C2', '+ x0 + * C5 ^ x1 E4 + * C5 ^ x1 E2 + * C5 x1 C4 [SEP] + ^ x1 E5 + * C2 ^ x1 E4 * C6 ^ x1 E2', '+ x0 + * C2 ^ x1 E4 + * C3 ^ x1 E3 + * C3 ^ x1 E2 + x1 C6 [SEP] + ^ x1 E5 + * C5 ^ x1 E3 + * C4 ^ x1 E2 + * C5 x1 C2', '+ x0 + * C3 ^ x1 E3 + ^ x1 E2 * C3 x1 [SEP] + ^ x1 E4 + * C3 ^ x1 E3 + * C2 ^ x1 E2 C3', '+ x0 + * C4 ^ x1 E4 + * C2 ^ x1 E3 + * C3 ^ x1 E2 x1 [SEP] + ^ x1 E5

{'acc': 0.666,
 'hits': [[True,
   True,
   False,
   False,
   True,
   True,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   True,
   True,
   True,
   True,
   True,
   True,
   True,
   True,
   False,
   False,
   True,
   False,
   True,
   False,
   False,
   False,
   True,
   True,
   True,
   True,
   True,
   True,
   True,
   True,
   True,
   False,
   False,
   True,
   True,
   False,
   True,
   True,
   False,
   True,
   False,
   True,
   True,
   True,
   True,
   False,
   False,
   False,
   False,
   False,
   False,
   True,
   True,
   True,
   True,
   True,
   True,
   True,
   False,
   True,
   True,
   True,
   True,
   True,
   True,
   False,
   True,
   True,
   True,
   True,
   False,
   True,
   False,
   True,
   True,
   True,
   True,
   False,
   True,
   True,
   True,
   False,
   True,
   True,
   False,
   True,
   True,
   False,
   False,
   True,
   True,
   False,
   False,
   True,
   False,
   True,
   True,
   

In [63]:
# Second evaluation run with the same call; the recorded accuracy differs from
# the run above (0.049 vs 0.666).
# NOTE(review): presumably the model, data, or function definition changed
# between the two executions — record which, so the numbers can be interpreted.
from time import time 
eval_prediction(model, test_loader, tokenizer, use_tqdm=False)

{'acc': 0.049,
 'hits': [[False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   True,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   True,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,
   False,