# 訓練模型, 省時間, 複製別人間單的範例即可

In [19]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AdamW, WarmupLinearSchedule

import csv
import json
import numpy as np
import pandas as pd


# 超参数
EPOCHS = 1  # 训练的轮数
BATCH_SIZE = 4  # 批大小
MAX_LEN = 10  # 文本最大长度
LR = 1e-5  # 学习率
WARMUP_STEPS = 100  # 热身步骤
T_TOTAL = 1000  # 总步骤

# pytorch的dataset类 重写getitem,len方法
class Custom_dataset(Dataset):
    def __init__(self, dataset_list):
        self.dataset = dataset_list

    def __getitem__(self, item):
        text = self.dataset[item][0]
        label = self.dataset[item][1]

        return text, label

    def __len__(self):
        return len(self.dataset)


# 加载数据集
def load_dataset(filepath, max_len):
    dataset_list = []
    f = open(filepath, 'r', encoding='utf-8')
    r = csv.reader(f)
    for item in r:
        if r.line_num == 1:
            continue
        dataset_list.append(item)
    
    # 根据max_len参数进行padding
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    for item in dataset_list:
        item[0] = item[0].replace(' ','')
        num = max_len - len(item[0])
        if num < 0:
            item[0] = item[0][:max_len]
            item[0] = tokenizer.encode(item[0]) 
            num_temp = max_len - len(item[0])
            if num_temp > 0:
                for _ in range(num_temp):
                    item[0].append(0)
            # 在开头和结尾加[CLS] [SEP]
            item[0] = [101] + item[0] + [102]
            item[0] = str(item[0])
            continue

        for _ in range(num):
            item[0] = item[0] + '[PAD]'
        item[0] = tokenizer.encode(item[0])
        num_temp = max_len - len(item[0])
        if num_temp > 0:
            for _ in range(num_temp):
                item[0].append(0)
        item[0] = [101] + item[0] + [102]
        item[0] = str(item[0])

    return dataset_list


# 计算每个batch的准确率
def  batch_accuracy(pre, label):
    pre = pre.argmax(dim=1)
    correct = torch.eq(pre, label).sum().float().item()
    accuracy = correct / float(len(label))

    return accuracy


if __name__ == "__main__":

    # 生成数据集以及迭代器
    train_dataset = load_dataset('/home/jacklee/ocr/food_classification/train_data/train_data.csv', max_len = MAX_LEN)  # 7337 * 3
    test_dataset = load_dataset('/home/jacklee/ocr/food_classification/train_data/test_data.csv', max_len = MAX_LEN)  # 7356 * 3
  
    train_cus = Custom_dataset(train_dataset)
    train_loader = DataLoader(dataset=train_cus, batch_size=BATCH_SIZE, shuffle=False)

    # Bert模型以及相关配置
    config = BertConfig.from_pretrained('bert-base-chinese')
    config.num_labels = 2
    model = BertForSequenceClassification(config = config)
    model = BertForSequenceClassification.from_pretrained('bert-base-chinese', config=config)
    model.cuda()

 
    optimizer = AdamW(model.parameters(), lr=LR, correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps = WARMUP_STEPS, t_total = T_TOTAL)

    # optimizer = optim.Adam(model.parameters(), lr=LR)

    model.train()
    print('开始训练...')
    for epoch in range(EPOCHS):
        for text, label in train_loader:
            text_list = list(map(json.loads, text))
            label_list = list(map(json.loads, label))
            
            text_tensor = torch.tensor(text_list).cuda()
            label_tensor = torch.tensor(label_list).cuda()

            outputs = model(text_tensor, labels=label_tensor)
            loss, logits = outputs[:2]
            optimizer.zero_grad()
            loss.backward()
            scheduler.step()
            optimizer.step()

            acc = batch_accuracy(logits, label_tensor)
            print('epoch:{} | acc:{} | loss:{}'.format(epoch, acc, loss))

    torch.save(model.state_dict(), 'bert_cla.ckpt')
    print('保存训练完成的model...')




开始训练...
epoch:0 | acc:1.0 | loss:0.30297398567199707
epoch:0 | acc:1.0 | loss:0.36378350853919983
epoch:0 | acc:1.0 | loss:0.35447001457214355
epoch:0 | acc:0.75 | loss:0.6050809621810913
epoch:0 | acc:0.75 | loss:0.5708495378494263
epoch:0 | acc:1.0 | loss:0.2989423871040344
epoch:0 | acc:1.0 | loss:0.26403796672821045
epoch:0 | acc:1.0 | loss:0.24721144139766693
epoch:0 | acc:0.75 | loss:0.607689619064331
epoch:0 | acc:0.75 | loss:0.562098503112793
epoch:0 | acc:1.0 | loss:0.23412813246250153
epoch:0 | acc:1.0 | loss:0.21824197471141815
epoch:0 | acc:0.75 | loss:0.5723121762275696
epoch:0 | acc:1.0 | loss:0.13467970490455627
epoch:0 | acc:0.75 | loss:0.6469825506210327
epoch:0 | acc:1.0 | loss:0.11735144257545471
epoch:0 | acc:0.5 | loss:0.9911912679672241
epoch:0 | acc:1.0 | loss:0.10627524554729462
epoch:0 | acc:1.0 | loss:0.13061529397964478
epoch:0 | acc:1.0 | loss:0.10724890232086182
epoch:0 | acc:1.0 | loss:0.10682299733161926
epoch:0 | acc:1.0 | loss:0.09847307205200195
epoch:

epoch:0 | acc:1.0 | loss:0.034328848123550415
epoch:0 | acc:1.0 | loss:0.006967723369598389
epoch:0 | acc:1.0 | loss:0.016558855772018433
epoch:0 | acc:0.75 | loss:1.2550219297409058
epoch:0 | acc:1.0 | loss:0.011753469705581665
epoch:0 | acc:1.0 | loss:0.004686117172241211
epoch:0 | acc:1.0 | loss:0.006470322608947754
epoch:0 | acc:1.0 | loss:0.009973883628845215
epoch:0 | acc:1.0 | loss:0.01245260238647461
epoch:0 | acc:1.0 | loss:0.006226718425750732
epoch:0 | acc:1.0 | loss:0.0060836076736450195
epoch:0 | acc:1.0 | loss:0.007462739944458008
epoch:0 | acc:1.0 | loss:0.014298558235168457
epoch:0 | acc:1.0 | loss:0.017199933528900146
epoch:0 | acc:0.75 | loss:1.3312264680862427
epoch:0 | acc:1.0 | loss:0.008185088634490967
epoch:0 | acc:1.0 | loss:0.016142457723617554
epoch:0 | acc:1.0 | loss:0.01286858320236206
epoch:0 | acc:0.75 | loss:0.43040338158607483
epoch:0 | acc:1.0 | loss:0.02402395009994507
epoch:0 | acc:1.0 | loss:0.029601484537124634
epoch:0 | acc:1.0 | loss:0.03676971793

epoch:0 | acc:1.0 | loss:0.034144818782806396
epoch:0 | acc:1.0 | loss:0.015330910682678223
epoch:0 | acc:1.0 | loss:0.0216829776763916
epoch:0 | acc:1.0 | loss:0.019836366176605225
epoch:0 | acc:1.0 | loss:0.018334388732910156
epoch:0 | acc:1.0 | loss:0.027567028999328613
epoch:0 | acc:1.0 | loss:0.012422800064086914
epoch:0 | acc:0.75 | loss:0.8739170432090759
epoch:0 | acc:1.0 | loss:0.014597654342651367
epoch:0 | acc:1.0 | loss:0.02012425661087036
epoch:0 | acc:1.0 | loss:0.01926809549331665
epoch:0 | acc:1.0 | loss:0.020918548107147217
epoch:0 | acc:1.0 | loss:0.01791912317276001
epoch:0 | acc:1.0 | loss:0.015513777732849121
epoch:0 | acc:1.0 | loss:0.02306187152862549
epoch:0 | acc:1.0 | loss:0.019641995429992676
epoch:0 | acc:1.0 | loss:0.01423555612564087
epoch:0 | acc:1.0 | loss:0.013192474842071533
epoch:0 | acc:1.0 | loss:0.011683881282806396
epoch:0 | acc:1.0 | loss:0.020115554332733154
epoch:0 | acc:0.75 | loss:0.8040379285812378
epoch:0 | acc:1.0 | loss:0.0124017596244812

epoch:0 | acc:1.0 | loss:0.023946940898895264
epoch:0 | acc:1.0 | loss:0.01330837607383728
epoch:0 | acc:1.0 | loss:0.006888985633850098
epoch:0 | acc:1.0 | loss:0.010212153196334839
epoch:0 | acc:1.0 | loss:0.00796884298324585
epoch:0 | acc:1.0 | loss:0.009006470441818237
epoch:0 | acc:1.0 | loss:0.004964888095855713
epoch:0 | acc:1.0 | loss:0.03279213607311249
epoch:0 | acc:1.0 | loss:0.008937180042266846
epoch:0 | acc:1.0 | loss:0.0067389607429504395
epoch:0 | acc:1.0 | loss:0.005784451961517334
epoch:0 | acc:1.0 | loss:0.03187626600265503
epoch:0 | acc:1.0 | loss:0.005149126052856445
epoch:0 | acc:1.0 | loss:0.00509798526763916
epoch:0 | acc:1.0 | loss:0.006238400936126709
epoch:0 | acc:1.0 | loss:0.005605101585388184
epoch:0 | acc:0.75 | loss:0.19487716257572174
epoch:0 | acc:1.0 | loss:0.003276646137237549
epoch:0 | acc:1.0 | loss:0.015805423259735107
epoch:0 | acc:1.0 | loss:0.0054999589920043945
epoch:0 | acc:1.0 | loss:0.022080451250076294
epoch:0 | acc:1.0 | loss:0.0087240636

epoch:0 | acc:1.0 | loss:0.00632089376449585
epoch:0 | acc:1.0 | loss:0.01410934329032898
epoch:0 | acc:1.0 | loss:0.009158402681350708
epoch:0 | acc:1.0 | loss:0.008288562297821045
epoch:0 | acc:1.0 | loss:0.017168641090393066
epoch:0 | acc:1.0 | loss:0.008682727813720703
epoch:0 | acc:1.0 | loss:0.021484896540641785
epoch:0 | acc:1.0 | loss:0.007960349321365356
epoch:0 | acc:1.0 | loss:0.007923543453216553
epoch:0 | acc:1.0 | loss:0.009266138076782227
epoch:0 | acc:1.0 | loss:0.06962758302688599
epoch:0 | acc:1.0 | loss:0.006891727447509766
epoch:0 | acc:1.0 | loss:0.007310092449188232
epoch:0 | acc:1.0 | loss:0.007467120885848999
epoch:0 | acc:1.0 | loss:0.006304025650024414
epoch:0 | acc:1.0 | loss:0.008829474449157715
epoch:0 | acc:1.0 | loss:0.006835460662841797
epoch:0 | acc:1.0 | loss:0.009741723537445068
epoch:0 | acc:0.75 | loss:0.25152596831321716
epoch:0 | acc:1.0 | loss:0.005833923816680908
epoch:0 | acc:1.0 | loss:0.008397042751312256
epoch:0 | acc:1.0 | loss:0.0055575370

epoch:0 | acc:1.0 | loss:0.004170894622802734
epoch:0 | acc:1.0 | loss:0.006263077259063721
epoch:0 | acc:1.0 | loss:0.03214263916015625
epoch:0 | acc:1.0 | loss:0.006296396255493164
epoch:0 | acc:1.0 | loss:0.007005035877227783
epoch:0 | acc:1.0 | loss:0.005124390125274658
epoch:0 | acc:1.0 | loss:0.004974305629730225
epoch:0 | acc:1.0 | loss:0.006192207336425781
epoch:0 | acc:1.0 | loss:0.0043354034423828125
epoch:0 | acc:1.0 | loss:0.005789458751678467
epoch:0 | acc:1.0 | loss:0.007323265075683594
epoch:0 | acc:1.0 | loss:0.01233738660812378
epoch:0 | acc:1.0 | loss:0.004482388496398926
epoch:0 | acc:1.0 | loss:0.00512385368347168
epoch:0 | acc:1.0 | loss:0.0081787109375
epoch:0 | acc:1.0 | loss:0.015797197818756104
epoch:0 | acc:1.0 | loss:0.01711246371269226
epoch:0 | acc:1.0 | loss:0.0037508606910705566
epoch:0 | acc:1.0 | loss:0.06084349751472473
epoch:0 | acc:0.75 | loss:0.19948837161064148
epoch:0 | acc:1.0 | loss:0.006966650485992432
epoch:0 | acc:1.0 | loss:0.005126774311065

epoch:0 | acc:1.0 | loss:0.00871288776397705
epoch:0 | acc:1.0 | loss:0.005991041660308838
epoch:0 | acc:1.0 | loss:0.01804909110069275
epoch:0 | acc:1.0 | loss:0.04264882206916809
epoch:0 | acc:1.0 | loss:0.005592942237854004
epoch:0 | acc:1.0 | loss:0.01177254319190979
epoch:0 | acc:1.0 | loss:0.00786679983139038
epoch:0 | acc:1.0 | loss:0.007029682397842407
epoch:0 | acc:1.0 | loss:0.008242547512054443
epoch:0 | acc:1.0 | loss:0.006267249584197998
epoch:0 | acc:1.0 | loss:0.004220902919769287
epoch:0 | acc:1.0 | loss:0.00474625825881958
epoch:0 | acc:1.0 | loss:0.015312105417251587
epoch:0 | acc:1.0 | loss:0.0798504501581192
epoch:0 | acc:1.0 | loss:0.006544679403305054
epoch:0 | acc:1.0 | loss:0.005406439304351807
epoch:0 | acc:1.0 | loss:0.004350781440734863
epoch:0 | acc:1.0 | loss:0.0071266889572143555
epoch:0 | acc:1.0 | loss:0.005962640047073364
epoch:0 | acc:1.0 | loss:0.004163801670074463
epoch:0 | acc:1.0 | loss:0.022902697324752808
epoch:0 | acc:1.0 | loss:0.00648069381713

epoch:0 | acc:1.0 | loss:0.007145822048187256
epoch:0 | acc:1.0 | loss:0.008157134056091309
epoch:0 | acc:1.0 | loss:0.00588458776473999
epoch:0 | acc:1.0 | loss:0.010095208883285522
epoch:0 | acc:1.0 | loss:0.008504748344421387
epoch:0 | acc:1.0 | loss:0.004833042621612549
epoch:0 | acc:1.0 | loss:0.010452061891555786
epoch:0 | acc:1.0 | loss:0.007683277130126953
epoch:0 | acc:1.0 | loss:0.008293837308883667
epoch:0 | acc:1.0 | loss:0.0068441033363342285
epoch:0 | acc:1.0 | loss:0.007345914840698242
epoch:0 | acc:1.0 | loss:0.0035779476165771484
epoch:0 | acc:1.0 | loss:0.007082700729370117
epoch:0 | acc:1.0 | loss:0.004221200942993164
epoch:0 | acc:1.0 | loss:0.006271839141845703
epoch:0 | acc:1.0 | loss:0.006031930446624756
epoch:0 | acc:1.0 | loss:0.006186425685882568
epoch:0 | acc:1.0 | loss:0.009641706943511963
epoch:0 | acc:1.0 | loss:0.012424111366271973
epoch:0 | acc:1.0 | loss:0.003629624843597412
epoch:0 | acc:1.0 | loss:0.004325270652770996
epoch:0 | acc:1.0 | loss:0.012441

epoch:0 | acc:1.0 | loss:0.004750311374664307
epoch:0 | acc:1.0 | loss:0.012369632720947266
epoch:0 | acc:1.0 | loss:0.01349639892578125
epoch:0 | acc:1.0 | loss:0.004656493663787842
epoch:0 | acc:1.0 | loss:0.009236633777618408
epoch:0 | acc:1.0 | loss:0.012907266616821289
epoch:0 | acc:1.0 | loss:0.010356128215789795
epoch:0 | acc:1.0 | loss:0.006725668907165527
epoch:0 | acc:1.0 | loss:0.005110502243041992
epoch:0 | acc:1.0 | loss:0.005281925201416016
epoch:0 | acc:1.0 | loss:0.004195034503936768
epoch:0 | acc:1.0 | loss:0.004372358322143555
epoch:0 | acc:0.75 | loss:0.5707242488861084
epoch:0 | acc:1.0 | loss:0.004922628402709961
epoch:0 | acc:1.0 | loss:0.010473817586898804
epoch:0 | acc:1.0 | loss:0.009856045246124268
epoch:0 | acc:1.0 | loss:0.019057631492614746
epoch:0 | acc:1.0 | loss:0.011466234922409058
epoch:0 | acc:1.0 | loss:0.004692792892456055
epoch:0 | acc:1.0 | loss:0.017953723669052124
epoch:0 | acc:1.0 | loss:0.005402743816375732
epoch:0 | acc:1.0 | loss:0.005451977

epoch:0 | acc:1.0 | loss:0.03928312659263611
epoch:0 | acc:1.0 | loss:0.0037841200828552246
epoch:0 | acc:1.0 | loss:0.006322622299194336
epoch:0 | acc:1.0 | loss:0.007098674774169922
epoch:0 | acc:1.0 | loss:0.006794750690460205
epoch:0 | acc:1.0 | loss:0.006545066833496094
epoch:0 | acc:1.0 | loss:0.005054354667663574
epoch:0 | acc:1.0 | loss:0.014736294746398926
epoch:0 | acc:1.0 | loss:0.006373703479766846
epoch:0 | acc:1.0 | loss:0.00590592622756958
epoch:0 | acc:1.0 | loss:0.005332052707672119
epoch:0 | acc:1.0 | loss:0.1487898826599121
epoch:0 | acc:1.0 | loss:0.004057347774505615
epoch:0 | acc:1.0 | loss:0.00906023383140564
epoch:0 | acc:1.0 | loss:0.00824078917503357
epoch:0 | acc:1.0 | loss:0.005809783935546875
epoch:0 | acc:1.0 | loss:0.017901986837387085
epoch:0 | acc:1.0 | loss:0.009625136852264404
epoch:0 | acc:1.0 | loss:0.005402266979217529
epoch:0 | acc:1.0 | loss:0.018921107053756714
epoch:0 | acc:1.0 | loss:0.013484776020050049
epoch:0 | acc:1.0 | loss:0.077959656715

epoch:0 | acc:1.0 | loss:0.004370927810668945
epoch:0 | acc:1.0 | loss:0.015186846256256104
epoch:0 | acc:1.0 | loss:0.003840804100036621
epoch:0 | acc:1.0 | loss:0.0071833133697509766
epoch:0 | acc:1.0 | loss:0.006442070007324219
epoch:0 | acc:1.0 | loss:0.005362808704376221
epoch:0 | acc:1.0 | loss:0.007661223411560059
epoch:0 | acc:1.0 | loss:0.008265584707260132
epoch:0 | acc:1.0 | loss:0.09641340374946594
epoch:0 | acc:1.0 | loss:0.012890458106994629
epoch:0 | acc:0.75 | loss:0.8117817044258118
epoch:0 | acc:1.0 | loss:0.01274988055229187
epoch:0 | acc:1.0 | loss:0.008775174617767334
epoch:0 | acc:1.0 | loss:0.009134650230407715
epoch:0 | acc:1.0 | loss:0.009875237941741943
epoch:0 | acc:1.0 | loss:0.008067488670349121
epoch:0 | acc:1.0 | loss:0.09230485558509827
epoch:0 | acc:1.0 | loss:0.005258023738861084
epoch:0 | acc:1.0 | loss:0.010145395994186401
epoch:0 | acc:0.75 | loss:0.45865598320961
epoch:0 | acc:1.0 | loss:0.013447016477584839
epoch:0 | acc:1.0 | loss:0.0049009323120

epoch:0 | acc:1.0 | loss:0.005694210529327393
epoch:0 | acc:1.0 | loss:0.017523884773254395
epoch:0 | acc:1.0 | loss:0.005918920040130615
epoch:0 | acc:1.0 | loss:0.005704164505004883
epoch:0 | acc:1.0 | loss:0.005862116813659668
epoch:0 | acc:1.0 | loss:0.013353109359741211
epoch:0 | acc:1.0 | loss:0.010651201009750366
epoch:0 | acc:1.0 | loss:0.005334794521331787
epoch:0 | acc:1.0 | loss:0.1203518658876419
epoch:0 | acc:1.0 | loss:0.005996763706207275
epoch:0 | acc:1.0 | loss:0.008246749639511108
epoch:0 | acc:0.75 | loss:0.18756286799907684
epoch:0 | acc:1.0 | loss:0.0062454938888549805
epoch:0 | acc:1.0 | loss:0.0053749680519104
epoch:0 | acc:1.0 | loss:0.004616677761077881
epoch:0 | acc:1.0 | loss:0.007710367441177368
epoch:0 | acc:1.0 | loss:0.0145263671875
epoch:0 | acc:1.0 | loss:0.008443504571914673
epoch:0 | acc:1.0 | loss:0.005060791969299316
epoch:0 | acc:1.0 | loss:0.011077821254730225
epoch:0 | acc:1.0 | loss:0.004100322723388672
epoch:0 | acc:1.0 | loss:0.008589297533035

epoch:0 | acc:1.0 | loss:0.007481664419174194
epoch:0 | acc:1.0 | loss:0.010939866304397583
epoch:0 | acc:1.0 | loss:0.004538893699645996
epoch:0 | acc:1.0 | loss:0.011417388916015625
epoch:0 | acc:1.0 | loss:0.009789109230041504
epoch:0 | acc:1.0 | loss:0.019019871950149536
epoch:0 | acc:1.0 | loss:0.004886269569396973
epoch:0 | acc:1.0 | loss:0.007198214530944824
epoch:0 | acc:1.0 | loss:0.006366312503814697
epoch:0 | acc:1.0 | loss:0.004511713981628418
epoch:0 | acc:0.75 | loss:0.2643629014492035
epoch:0 | acc:1.0 | loss:0.006287634372711182
epoch:0 | acc:1.0 | loss:0.009476691484451294
epoch:0 | acc:0.75 | loss:0.39758485555648804
epoch:0 | acc:1.0 | loss:0.013752758502960205
epoch:0 | acc:1.0 | loss:0.005014955997467041
epoch:0 | acc:1.0 | loss:0.0044460296630859375
epoch:0 | acc:1.0 | loss:0.07036401331424713
epoch:0 | acc:1.0 | loss:0.007317841053009033
epoch:0 | acc:0.75 | loss:0.9468446373939514
epoch:0 | acc:1.0 | loss:0.004246950149536133
epoch:0 | acc:1.0 | loss:0.012432605

epoch:0 | acc:1.0 | loss:0.012160122394561768
epoch:0 | acc:1.0 | loss:0.011508554220199585
epoch:0 | acc:1.0 | loss:0.009285897016525269
epoch:0 | acc:1.0 | loss:0.007160782814025879
epoch:0 | acc:1.0 | loss:0.013372033834457397
epoch:0 | acc:1.0 | loss:0.005154013633728027
epoch:0 | acc:1.0 | loss:0.005510985851287842
epoch:0 | acc:1.0 | loss:0.008374273777008057
epoch:0 | acc:1.0 | loss:0.006286919116973877
epoch:0 | acc:1.0 | loss:0.007746607065200806
epoch:0 | acc:1.0 | loss:0.005256354808807373
epoch:0 | acc:1.0 | loss:0.004871785640716553
epoch:0 | acc:1.0 | loss:0.007623612880706787
epoch:0 | acc:1.0 | loss:0.011428922414779663
epoch:0 | acc:1.0 | loss:0.013173222541809082
epoch:0 | acc:1.0 | loss:0.007412850856781006
epoch:0 | acc:1.0 | loss:0.018147051334381104
epoch:0 | acc:1.0 | loss:0.004761755466461182
epoch:0 | acc:1.0 | loss:0.007940292358398438
epoch:0 | acc:1.0 | loss:0.0066721439361572266
epoch:0 | acc:1.0 | loss:0.005970597267150879
epoch:0 | acc:1.0 | loss:0.010470

epoch:0 | acc:1.0 | loss:0.008676648139953613
epoch:0 | acc:1.0 | loss:0.016535699367523193
epoch:0 | acc:1.0 | loss:0.006335049867630005
epoch:0 | acc:1.0 | loss:0.019633114337921143
epoch:0 | acc:1.0 | loss:0.010264545679092407
epoch:0 | acc:1.0 | loss:0.007193148136138916
epoch:0 | acc:1.0 | loss:0.008984982967376709
epoch:0 | acc:1.0 | loss:0.013195335865020752
epoch:0 | acc:1.0 | loss:0.007943570613861084
epoch:0 | acc:1.0 | loss:0.010653167963027954
epoch:0 | acc:1.0 | loss:0.006850898265838623
epoch:0 | acc:1.0 | loss:0.008582651615142822
epoch:0 | acc:1.0 | loss:0.009713351726531982
epoch:0 | acc:1.0 | loss:0.010885834693908691
epoch:0 | acc:1.0 | loss:0.08988459408283234
epoch:0 | acc:1.0 | loss:0.005556881427764893
epoch:0 | acc:1.0 | loss:0.013847768306732178
epoch:0 | acc:1.0 | loss:0.00919845700263977
epoch:0 | acc:1.0 | loss:0.004675745964050293
epoch:0 | acc:1.0 | loss:0.006350517272949219
epoch:0 | acc:1.0 | loss:0.17644302546977997
epoch:0 | acc:1.0 | loss:0.0090923011

epoch:0 | acc:1.0 | loss:0.008701205253601074
epoch:0 | acc:1.0 | loss:0.16160091757774353
epoch:0 | acc:1.0 | loss:0.006296873092651367
epoch:0 | acc:1.0 | loss:0.03191611170768738
epoch:0 | acc:1.0 | loss:0.004300832748413086
epoch:0 | acc:1.0 | loss:0.007158100605010986
epoch:0 | acc:1.0 | loss:0.015025436878204346
epoch:0 | acc:1.0 | loss:0.004382133483886719
epoch:0 | acc:1.0 | loss:0.012945950031280518
epoch:0 | acc:1.0 | loss:0.008875936269760132
epoch:0 | acc:1.0 | loss:0.0088539719581604
epoch:0 | acc:1.0 | loss:0.0058785080909729
epoch:0 | acc:1.0 | loss:0.02336445450782776
epoch:0 | acc:1.0 | loss:0.004042327404022217
epoch:0 | acc:1.0 | loss:0.006581008434295654
epoch:0 | acc:1.0 | loss:0.00907185673713684
epoch:0 | acc:1.0 | loss:0.009222686290740967
epoch:0 | acc:1.0 | loss:0.006459712982177734
epoch:0 | acc:1.0 | loss:0.009714126586914062
epoch:0 | acc:1.0 | loss:0.016600847244262695
epoch:0 | acc:1.0 | loss:0.011979132890701294
epoch:0 | acc:1.0 | loss:0.006728112697601

epoch:0 | acc:1.0 | loss:0.007939308881759644
epoch:0 | acc:1.0 | loss:0.006124973297119141
epoch:0 | acc:1.0 | loss:0.009295672178268433
epoch:0 | acc:1.0 | loss:0.004067063331604004
epoch:0 | acc:1.0 | loss:0.009959489107131958
epoch:0 | acc:1.0 | loss:0.005535542964935303
epoch:0 | acc:1.0 | loss:0.010305941104888916
epoch:0 | acc:1.0 | loss:0.008445799350738525
epoch:0 | acc:1.0 | loss:0.02100345492362976
epoch:0 | acc:1.0 | loss:0.008829891681671143
epoch:0 | acc:1.0 | loss:0.005459010601043701
epoch:0 | acc:1.0 | loss:0.0036901235580444336
epoch:0 | acc:1.0 | loss:0.009607642889022827
epoch:0 | acc:1.0 | loss:0.010044485330581665
epoch:0 | acc:1.0 | loss:0.019570916891098022
epoch:0 | acc:1.0 | loss:0.005896151065826416
epoch:0 | acc:1.0 | loss:0.015225976705551147
epoch:0 | acc:1.0 | loss:0.009298890829086304
epoch:0 | acc:1.0 | loss:0.007209718227386475
epoch:0 | acc:1.0 | loss:0.014669984579086304
epoch:0 | acc:1.0 | loss:0.003749668598175049
epoch:0 | acc:1.0 | loss:0.0085999

epoch:0 | acc:1.0 | loss:0.008423566818237305
epoch:0 | acc:1.0 | loss:0.01013079285621643
epoch:0 | acc:1.0 | loss:0.006464362144470215
epoch:0 | acc:0.75 | loss:1.064699649810791
epoch:0 | acc:1.0 | loss:0.009926378726959229
epoch:0 | acc:1.0 | loss:0.004477739334106445
epoch:0 | acc:1.0 | loss:0.00669407844543457
epoch:0 | acc:1.0 | loss:0.005860447883605957
epoch:0 | acc:1.0 | loss:0.0056833624839782715
epoch:0 | acc:1.0 | loss:0.00550156831741333
epoch:0 | acc:1.0 | loss:0.010685741901397705
epoch:0 | acc:1.0 | loss:0.007405400276184082
epoch:0 | acc:1.0 | loss:0.00402069091796875
epoch:0 | acc:1.0 | loss:0.0102730393409729
epoch:0 | acc:1.0 | loss:0.00898289680480957
epoch:0 | acc:1.0 | loss:0.006340324878692627
epoch:0 | acc:1.0 | loss:0.007461249828338623
epoch:0 | acc:1.0 | loss:0.010079383850097656
epoch:0 | acc:1.0 | loss:0.0069196224212646484
epoch:0 | acc:1.0 | loss:0.004564881324768066
epoch:0 | acc:1.0 | loss:0.004017829895019531
epoch:0 | acc:1.0 | loss:0.01107472181320

epoch:0 | acc:1.0 | loss:0.011830180883407593
epoch:0 | acc:1.0 | loss:0.08675637096166611
epoch:0 | acc:1.0 | loss:0.0083121657371521
epoch:0 | acc:1.0 | loss:0.004504501819610596
epoch:0 | acc:1.0 | loss:0.007326364517211914
epoch:0 | acc:1.0 | loss:0.009067147970199585
epoch:0 | acc:0.5 | loss:1.4442167282104492
epoch:0 | acc:1.0 | loss:0.012583702802658081
epoch:0 | acc:1.0 | loss:0.011559218168258667
epoch:0 | acc:1.0 | loss:0.015339463949203491
epoch:0 | acc:1.0 | loss:0.01840338110923767
epoch:0 | acc:1.0 | loss:0.00554347038269043
epoch:0 | acc:1.0 | loss:0.006250739097595215
epoch:0 | acc:1.0 | loss:0.009175717830657959
epoch:0 | acc:1.0 | loss:0.0045539140701293945
epoch:0 | acc:1.0 | loss:0.005230367183685303
epoch:0 | acc:1.0 | loss:0.0071800947189331055
epoch:0 | acc:1.0 | loss:0.009985417127609253
epoch:0 | acc:1.0 | loss:0.0049495697021484375
epoch:0 | acc:0.75 | loss:0.4554944336414337
epoch:0 | acc:1.0 | loss:0.01176270842552185
epoch:0 | acc:1.0 | loss:0.0060399174690

epoch:0 | acc:1.0 | loss:0.0070407092571258545
epoch:0 | acc:1.0 | loss:0.006859570741653442
epoch:0 | acc:1.0 | loss:0.013522356748580933
epoch:0 | acc:1.0 | loss:0.01624104380607605
epoch:0 | acc:1.0 | loss:0.0032393932342529297
epoch:0 | acc:1.0 | loss:0.012339264154434204
epoch:0 | acc:1.0 | loss:0.012525796890258789
epoch:0 | acc:1.0 | loss:0.005599617958068848
epoch:0 | acc:1.0 | loss:0.00492250919342041
epoch:0 | acc:1.0 | loss:0.014203965663909912
epoch:0 | acc:1.0 | loss:0.005622744560241699
epoch:0 | acc:1.0 | loss:0.005006730556488037
epoch:0 | acc:1.0 | loss:0.006281733512878418
epoch:0 | acc:1.0 | loss:0.006376922130584717
epoch:0 | acc:1.0 | loss:0.004205763339996338
epoch:0 | acc:1.0 | loss:0.005929231643676758
epoch:0 | acc:1.0 | loss:0.005448758602142334
epoch:0 | acc:1.0 | loss:0.052671097218990326
epoch:0 | acc:1.0 | loss:0.00658068060874939
epoch:0 | acc:0.75 | loss:1.1186246871948242
epoch:0 | acc:1.0 | loss:0.008690357208251953
epoch:0 | acc:1.0 | loss:0.009974241

epoch:0 | acc:1.0 | loss:0.0104522705078125
epoch:0 | acc:1.0 | loss:0.011742979288101196
epoch:0 | acc:1.0 | loss:0.028476953506469727
epoch:0 | acc:1.0 | loss:0.004827558994293213
epoch:0 | acc:1.0 | loss:0.011460542678833008
epoch:0 | acc:1.0 | loss:0.007378816604614258
epoch:0 | acc:0.75 | loss:0.18573316931724548
epoch:0 | acc:1.0 | loss:0.015569806098937988
epoch:0 | acc:0.75 | loss:0.9768418669700623
epoch:0 | acc:1.0 | loss:0.00804939866065979
epoch:0 | acc:1.0 | loss:0.011866927146911621
epoch:0 | acc:1.0 | loss:0.00550311803817749
epoch:0 | acc:1.0 | loss:0.02095705270767212
epoch:0 | acc:1.0 | loss:0.013025343418121338
epoch:0 | acc:1.0 | loss:0.010349541902542114
epoch:0 | acc:1.0 | loss:0.009685784578323364
epoch:0 | acc:1.0 | loss:0.00773957371711731
epoch:0 | acc:1.0 | loss:0.00600925087928772
epoch:0 | acc:1.0 | loss:0.012992680072784424
epoch:0 | acc:1.0 | loss:0.005210816860198975
epoch:0 | acc:1.0 | loss:0.007799386978149414
epoch:0 | acc:1.0 | loss:0.004326641559600

epoch:0 | acc:1.0 | loss:0.008397936820983887
epoch:0 | acc:1.0 | loss:0.006780028343200684
epoch:0 | acc:1.0 | loss:0.008031100034713745
epoch:0 | acc:1.0 | loss:0.0060999393463134766
epoch:0 | acc:1.0 | loss:0.007135152816772461
epoch:0 | acc:1.0 | loss:0.0053261518478393555
epoch:0 | acc:1.0 | loss:0.009996354579925537
epoch:0 | acc:1.0 | loss:0.0067182183265686035
epoch:0 | acc:1.0 | loss:0.00828695297241211
epoch:0 | acc:1.0 | loss:0.004160642623901367
epoch:0 | acc:1.0 | loss:0.009096413850784302
epoch:0 | acc:1.0 | loss:0.006989151239395142
epoch:0 | acc:0.75 | loss:1.0454161167144775
epoch:0 | acc:1.0 | loss:0.005891203880310059
epoch:0 | acc:1.0 | loss:0.030863791704177856
epoch:0 | acc:1.0 | loss:0.01522219181060791
epoch:0 | acc:1.0 | loss:0.004529237747192383
epoch:0 | acc:1.0 | loss:0.004733443260192871
epoch:0 | acc:1.0 | loss:0.010047942399978638
epoch:0 | acc:0.5 | loss:1.3875243663787842
epoch:0 | acc:1.0 | loss:0.009571969509124756
epoch:0 | acc:1.0 | loss:0.007009774

epoch:0 | acc:1.0 | loss:0.007212996482849121
epoch:0 | acc:1.0 | loss:0.01328134536743164
epoch:0 | acc:1.0 | loss:0.013776898384094238
epoch:0 | acc:1.0 | loss:0.01621478796005249
epoch:0 | acc:1.0 | loss:0.004192531108856201
epoch:0 | acc:1.0 | loss:0.036307573318481445
epoch:0 | acc:1.0 | loss:0.007779359817504883
epoch:0 | acc:1.0 | loss:0.011901795864105225
epoch:0 | acc:1.0 | loss:0.004958212375640869
epoch:0 | acc:1.0 | loss:0.012614786624908447
epoch:0 | acc:1.0 | loss:0.005038261413574219
epoch:0 | acc:1.0 | loss:0.006046950817108154
epoch:0 | acc:1.0 | loss:0.008627653121948242
epoch:0 | acc:1.0 | loss:0.009903192520141602
epoch:0 | acc:1.0 | loss:0.02886795997619629
epoch:0 | acc:1.0 | loss:0.008565545082092285
epoch:0 | acc:1.0 | loss:0.0057204365730285645
epoch:0 | acc:1.0 | loss:0.010759979486465454
epoch:0 | acc:0.75 | loss:1.0771578550338745
epoch:0 | acc:1.0 | loss:0.006088197231292725
epoch:0 | acc:1.0 | loss:0.01414594054222107
epoch:0 | acc:1.0 | loss:0.00728130340

epoch:0 | acc:1.0 | loss:0.003855466842651367
epoch:0 | acc:1.0 | loss:0.007031261920928955
epoch:0 | acc:1.0 | loss:0.009644865989685059
epoch:0 | acc:1.0 | loss:0.003748178482055664
epoch:0 | acc:1.0 | loss:0.004486143589019775
epoch:0 | acc:1.0 | loss:0.030641108751296997
epoch:0 | acc:1.0 | loss:0.16491656005382538
epoch:0 | acc:1.0 | loss:0.004357814788818359
epoch:0 | acc:1.0 | loss:0.009966164827346802
epoch:0 | acc:1.0 | loss:0.006394743919372559
epoch:0 | acc:0.75 | loss:0.4195846617221832
epoch:0 | acc:1.0 | loss:0.006606101989746094
epoch:0 | acc:1.0 | loss:0.1149234026670456
epoch:0 | acc:0.75 | loss:0.7293254733085632
epoch:0 | acc:0.75 | loss:0.8555155992507935
epoch:0 | acc:1.0 | loss:0.009013324975967407
epoch:0 | acc:1.0 | loss:0.004112958908081055
epoch:0 | acc:1.0 | loss:0.010915935039520264
epoch:0 | acc:1.0 | loss:0.004182755947113037
epoch:0 | acc:1.0 | loss:0.020718246698379517
epoch:0 | acc:1.0 | loss:0.006284773349761963
epoch:0 | acc:1.0 | loss:0.0078472495079

epoch:0 | acc:1.0 | loss:0.01177486777305603
epoch:0 | acc:1.0 | loss:0.009018093347549438
epoch:0 | acc:1.0 | loss:0.007220089435577393
epoch:0 | acc:1.0 | loss:0.00897100567817688
epoch:0 | acc:1.0 | loss:0.00607675313949585
epoch:0 | acc:1.0 | loss:0.005281984806060791
epoch:0 | acc:1.0 | loss:0.006032407283782959
epoch:0 | acc:0.75 | loss:0.6167954206466675
epoch:0 | acc:1.0 | loss:0.010991960763931274
epoch:0 | acc:1.0 | loss:0.009065002202987671
epoch:0 | acc:1.0 | loss:0.009862303733825684
epoch:0 | acc:1.0 | loss:0.008957624435424805
epoch:0 | acc:1.0 | loss:0.018886923789978027
epoch:0 | acc:1.0 | loss:0.02107909321784973
epoch:0 | acc:1.0 | loss:0.007242858409881592
epoch:0 | acc:1.0 | loss:0.006920158863067627
epoch:0 | acc:1.0 | loss:0.013100147247314453
epoch:0 | acc:1.0 | loss:0.005887746810913086
epoch:0 | acc:1.0 | loss:0.010458797216415405
epoch:0 | acc:1.0 | loss:0.005105197429656982
epoch:0 | acc:1.0 | loss:0.026780515909194946
epoch:0 | acc:1.0 | loss:0.005708932876

epoch:0 | acc:1.0 | loss:0.004068255424499512
epoch:0 | acc:1.0 | loss:0.006596207618713379
epoch:0 | acc:1.0 | loss:0.01110878586769104
epoch:0 | acc:1.0 | loss:0.018574774265289307
epoch:0 | acc:1.0 | loss:0.0047293901443481445
epoch:0 | acc:1.0 | loss:0.004323780536651611
epoch:0 | acc:1.0 | loss:0.006562232971191406
epoch:0 | acc:1.0 | loss:0.0037909746170043945
epoch:0 | acc:1.0 | loss:0.004332900047302246
epoch:0 | acc:1.0 | loss:0.00486069917678833
epoch:0 | acc:1.0 | loss:0.007985800504684448
epoch:0 | acc:1.0 | loss:0.004725217819213867
epoch:0 | acc:1.0 | loss:0.005504786968231201
epoch:0 | acc:1.0 | loss:0.00852280855178833
epoch:0 | acc:1.0 | loss:0.017870157957077026
epoch:0 | acc:1.0 | loss:0.009680509567260742
epoch:0 | acc:1.0 | loss:0.016098707914352417
epoch:0 | acc:1.0 | loss:0.030789881944656372
epoch:0 | acc:1.0 | loss:0.004774630069732666
epoch:0 | acc:1.0 | loss:0.008902162313461304
epoch:0 | acc:1.0 | loss:0.007097065448760986
epoch:0 | acc:1.0 | loss:0.01886969

epoch:0 | acc:1.0 | loss:0.005197346210479736
epoch:0 | acc:1.0 | loss:0.16800856590270996
epoch:0 | acc:1.0 | loss:0.04232838749885559
epoch:0 | acc:1.0 | loss:0.0052831172943115234
epoch:0 | acc:1.0 | loss:0.008033335208892822
epoch:0 | acc:1.0 | loss:0.0037105679512023926
epoch:0 | acc:1.0 | loss:0.0055852532386779785
epoch:0 | acc:1.0 | loss:0.008139342069625854
epoch:0 | acc:1.0 | loss:0.008014827966690063
epoch:0 | acc:1.0 | loss:0.00840449333190918
epoch:0 | acc:1.0 | loss:0.0076351165771484375
epoch:0 | acc:1.0 | loss:0.009659230709075928
epoch:0 | acc:1.0 | loss:0.005749762058258057
epoch:0 | acc:1.0 | loss:0.004809379577636719
epoch:0 | acc:1.0 | loss:0.006838202476501465
epoch:0 | acc:1.0 | loss:0.006309151649475098
epoch:0 | acc:1.0 | loss:0.014390379190444946
epoch:0 | acc:1.0 | loss:0.011154770851135254
epoch:0 | acc:1.0 | loss:0.0077263712882995605
epoch:0 | acc:1.0 | loss:0.008360475301742554
epoch:0 | acc:1.0 | loss:0.024320513010025024
epoch:0 | acc:1.0 | loss:0.00690

epoch:0 | acc:1.0 | loss:0.007468074560165405
epoch:0 | acc:1.0 | loss:0.004643738269805908
epoch:0 | acc:1.0 | loss:0.028866559267044067
epoch:0 | acc:0.75 | loss:0.7025078535079956
epoch:0 | acc:1.0 | loss:0.006461977958679199
epoch:0 | acc:1.0 | loss:0.004318535327911377
epoch:0 | acc:1.0 | loss:0.015159934759140015
epoch:0 | acc:1.0 | loss:0.013144195079803467
epoch:0 | acc:1.0 | loss:0.007957756519317627
epoch:0 | acc:1.0 | loss:0.008912861347198486
epoch:0 | acc:1.0 | loss:0.00440371036529541
epoch:0 | acc:1.0 | loss:0.003987491130828857
epoch:0 | acc:1.0 | loss:0.00639110803604126
epoch:0 | acc:1.0 | loss:0.004339039325714111
epoch:0 | acc:1.0 | loss:0.019652187824249268
epoch:0 | acc:1.0 | loss:0.006561458110809326
epoch:0 | acc:1.0 | loss:0.007726907730102539
epoch:0 | acc:1.0 | loss:0.012129902839660645
epoch:0 | acc:1.0 | loss:0.004225492477416992
epoch:0 | acc:1.0 | loss:0.007242560386657715
epoch:0 | acc:1.0 | loss:0.00395125150680542
epoch:0 | acc:1.0 | loss:0.01000970602

epoch:0 | acc:1.0 | loss:0.004758596420288086
epoch:0 | acc:1.0 | loss:0.0049304962158203125
epoch:0 | acc:1.0 | loss:0.03215564787387848
epoch:0 | acc:1.0 | loss:0.09762651473283768
epoch:0 | acc:1.0 | loss:0.005205750465393066
epoch:0 | acc:1.0 | loss:0.0077065229415893555
epoch:0 | acc:1.0 | loss:0.01020166277885437
epoch:0 | acc:1.0 | loss:0.10824349522590637
epoch:0 | acc:1.0 | loss:0.006628930568695068
epoch:0 | acc:1.0 | loss:0.007532298564910889
epoch:0 | acc:1.0 | loss:0.006157219409942627
epoch:0 | acc:1.0 | loss:0.00904381275177002
epoch:0 | acc:1.0 | loss:0.007460176944732666
epoch:0 | acc:1.0 | loss:0.013012230396270752
epoch:0 | acc:1.0 | loss:0.005010545253753662
epoch:0 | acc:1.0 | loss:0.009113818407058716
epoch:0 | acc:1.0 | loss:0.0052277445793151855
epoch:0 | acc:0.75 | loss:1.0032687187194824
epoch:0 | acc:0.75 | loss:1.0969860553741455
epoch:0 | acc:1.0 | loss:0.011956125497817993
epoch:0 | acc:1.0 | loss:0.0050778985023498535
epoch:0 | acc:0.75 | loss:0.961865544

epoch:0 | acc:1.0 | loss:0.014168530702590942
epoch:0 | acc:0.75 | loss:0.8547548651695251
epoch:0 | acc:1.0 | loss:0.009477883577346802
epoch:0 | acc:1.0 | loss:0.0042150020599365234
epoch:0 | acc:1.0 | loss:0.0049623847007751465
epoch:0 | acc:1.0 | loss:0.006120502948760986
epoch:0 | acc:0.75 | loss:0.9840179085731506
epoch:0 | acc:1.0 | loss:0.010898739099502563
epoch:0 | acc:1.0 | loss:0.005217790603637695
epoch:0 | acc:1.0 | loss:0.0032100677490234375
epoch:0 | acc:1.0 | loss:0.018144726753234863
epoch:0 | acc:1.0 | loss:0.010362565517425537
epoch:0 | acc:1.0 | loss:0.00663033127784729
epoch:0 | acc:1.0 | loss:0.01077917218208313
epoch:0 | acc:0.75 | loss:1.1223700046539307
epoch:0 | acc:1.0 | loss:0.00610804557800293
epoch:0 | acc:1.0 | loss:0.006352782249450684
epoch:0 | acc:1.0 | loss:0.007063448429107666
epoch:0 | acc:1.0 | loss:0.005942583084106445
epoch:0 | acc:1.0 | loss:0.006690263748168945
epoch:0 | acc:1.0 | loss:0.005297064781188965
epoch:0 | acc:1.0 | loss:0.0079087615

# 載入訓練好的模型並測試結果

In [20]:
model.load_state_dict(torch.load('bert_cla.ckpt'))

print('开始测试...')
model.eval()
test_result = []
for item in test_dataset:

    text_list = list(json.loads(item[0]))
    text_tensor = torch.tensor(text_list).unsqueeze(0).cuda()

    with torch.no_grad():

        # print('list', text_list)
        # print('tensor', text_tensor)
        # print('tensor.shape', text_tensor.shape)
        outputs = model(text_tensor, labels=None)

        print(outputs[0])

        pre = outputs[0].argmax(dim=1)
        test_result.append([item[0], pre.item()])

开始测试...
tensor([[ 3.0615, -2.7079]], device='cuda:0')
tensor([[-1.9948,  1.4176]], device='cuda:0')
tensor([[ 2.7601, -2.4883]], device='cuda:0')
tensor([[ 1.7807, -1.2395]], device='cuda:0')
tensor([[ 2.6653, -2.0854]], device='cuda:0')
tensor([[-2.2404,  1.9680]], device='cuda:0')
tensor([[-2.2323,  1.9231]], device='cuda:0')
tensor([[ 2.9455, -2.7364]], device='cuda:0')
tensor([[ 2.8297, -2.5776]], device='cuda:0')
tensor([[ 2.4902, -1.8415]], device='cuda:0')
tensor([[ 3.0042, -2.7335]], device='cuda:0')
tensor([[ 2.9092, -2.6257]], device='cuda:0')
tensor([[ 2.9912, -2.7522]], device='cuda:0')
tensor([[ 2.9797, -2.5760]], device='cuda:0')
tensor([[ 2.8414, -2.5982]], device='cuda:0')
tensor([[ 2.9830, -2.6804]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.8223, -2.5544]], device='cuda:0')
tensor([[ 2.6408, -2.5349]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[ 2.8784, -2.2536]

tensor([[ 2.9718, -2.6257]], device='cuda:0')
tensor([[-2.2441,  1.9230]], device='cuda:0')
tensor([[ 2.9692, -2.6709]], device='cuda:0')
tensor([[ 2.7475, -2.4376]], device='cuda:0')
tensor([[ 2.7927, -2.3282]], device='cuda:0')
tensor([[ 0.4549, -0.4818]], device='cuda:0')
tensor([[ 3.0394, -2.5635]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.9330, -2.6665]], device='cuda:0')
tensor([[ 3.0476, -2.7428]], device='cuda:0')
tensor([[ 3.0779, -2.7099]], device='cuda:0')
tensor([[ 2.9779, -2.7036]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.9034, -2.4493]], device='cuda:0')
tensor([[ 2.9641, -2.6142]], device='cuda:0')
tensor([[ 2.8225, -2.4341]], device='cuda:0')
tensor([[ 3.0741, -2.7194]], device='cuda:0')
tensor([[ 2.7601, -2.4883]], device='cuda:0')
tensor([[-2.2905,  1.9915]], device='cuda:0')
tensor([[ 2.9967, -2.5391]], devic

tensor([[-2.1584,  1.9090]], device='cuda:0')
tensor([[-2.2356,  1.9793]], device='cuda:0')
tensor([[ 2.7344, -2.2938]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[-0.0753, -0.1548]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.9348, -2.7138]], device='cuda:0')
tensor([[ 2.7436, -2.5451]], device='cuda:0')
tensor([[ 2.9339, -2.6246]], device='cuda:0')
tensor([[ 2.9935, -2.7252]], device='cuda:0')
tensor([[ 2.7212, -2.4576]], device='cuda:0')
tensor([[ 2.9491, -2.7330]], device='cuda:0')
tensor([[ 3.0209, -2.7351]], device='cuda:0')
tensor([[ 3.0447, -2.6782]], device='cuda:0')
tensor([[-1.8865,  1.4318]], device='cuda:0')
tensor([[ 2.9308, -2.7652]], device='cuda:0')
tensor([[ 2.5979, -2.3179]], device='cuda:0')
tensor([[ 1.1333, -0.7633]], device='cuda:0')
tensor([[ 3.0241, -2.7331]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[ 2.9144, -2.6407]], device='cuda:0')
tensor([[ 2.9145, -2.4813]], devic

tensor([[ 2.8116, -2.4789]], device='cuda:0')
tensor([[-2.1996,  1.6818]], device='cuda:0')
tensor([[ 2.9896, -2.6582]], device='cuda:0')
tensor([[ 2.6401, -2.4508]], device='cuda:0')
tensor([[ 1.9730, -1.1346]], device='cuda:0')
tensor([[ 2.9731, -2.5871]], device='cuda:0')
tensor([[ 2.8840, -2.6359]], device='cuda:0')
tensor([[ 1.4706, -1.2306]], device='cuda:0')
tensor([[ 2.6229, -2.3125]], device='cuda:0')
tensor([[ 2.9967, -2.6345]], device='cuda:0')
tensor([[ 2.9026, -2.5535]], device='cuda:0')
tensor([[ 3.0440, -2.7570]], device='cuda:0')
tensor([[-2.3225,  1.9774]], device='cuda:0')
tensor([[ 2.8403, -2.5871]], device='cuda:0')
tensor([[ 3.0665, -2.7364]], device='cuda:0')
tensor([[ 2.6473, -2.0137]], device='cuda:0')
tensor([[ 2.9728, -2.6542]], device='cuda:0')
tensor([[ 2.9858, -2.7457]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 3.0461, -2.6425]], device='cuda:0')
tensor([[ 2.7643, -2.6408]], device='cuda:0')
tensor([[-2.2042,  1.9679]], devic

tensor([[ 2.9440, -2.6454]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.8030, -2.4404]], device='cuda:0')
tensor([[ 2.8795, -2.6087]], device='cuda:0')
tensor([[ 2.9981, -2.6866]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[-1.9239,  1.3650]], device='cuda:0')
tensor([[ 3.0262, -2.7520]], device='cuda:0')
tensor([[ 2.6842, -2.4571]], device='cuda:0')
tensor([[ 2.9360, -2.5559]], device='cuda:0')
tensor([[ 2.8715, -2.1979]], device='cuda:0')
tensor([[ 2.7436, -2.5451]], device='cuda:0')
tensor([[ 2.8095, -2.4349]], device='cuda:0')
tensor([[ 3.0429, -2.7147]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[-2.1978,  1.8748]], device='cuda:0')
tensor([[-0.4610,  0.2364]], device='cuda:0')
tensor([[-2.2861,  1.8415]], device='cuda:0')
tensor([[ 3.0499, -2.7255]], device='cuda:0')
tensor([[ 3.0647, -2.6768]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.7708, -2.4477]], devic

tensor([[ 2.6831, -2.3368]], device='cuda:0')
tensor([[ 3.0418, -2.7297]], device='cuda:0')
tensor([[ 2.9692, -2.6709]], device='cuda:0')
tensor([[ 2.8521, -2.6672]], device='cuda:0')
tensor([[ 2.9026, -2.5535]], device='cuda:0')
tensor([[ 2.7502, -2.6243]], device='cuda:0')
tensor([[ 2.9822, -2.6423]], device='cuda:0')
tensor([[ 3.0370, -2.7278]], device='cuda:0')
tensor([[ 2.9733, -2.6947]], device='cuda:0')
tensor([[-2.2662,  1.8748]], device='cuda:0')
tensor([[ 2.9632, -2.7188]], device='cuda:0')
tensor([[ 2.6076, -2.3331]], device='cuda:0')
tensor([[ 2.9360, -2.5559]], device='cuda:0')
tensor([[-0.0862, -0.1439]], device='cuda:0')
tensor([[ 3.0362, -2.7389]], device='cuda:0')
tensor([[-2.2990,  1.7417]], device='cuda:0')
tensor([[ 2.6401, -2.4508]], device='cuda:0')
tensor([[ 3.0528, -2.7234]], device='cuda:0')
tensor([[ 2.9682, -2.7685]], device='cuda:0')
tensor([[ 2.8629, -2.5320]], device='cuda:0')
tensor([[-2.1674,  1.8498]], device='cuda:0')
tensor([[ 2.9026, -2.5535]], devic

tensor([[ 2.8715, -2.1979]], device='cuda:0')
tensor([[ 3.0202, -2.7241]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[ 2.8297, -2.5776]], device='cuda:0')
tensor([[ 2.8032, -2.2525]], device='cuda:0')
tensor([[ 2.9267, -2.6588]], device='cuda:0')
tensor([[ 2.9692, -2.6709]], device='cuda:0')
tensor([[-2.1783,  1.8365]], device='cuda:0')
tensor([[-2.2588,  1.9666]], device='cuda:0')
tensor([[ 2.9849, -2.6302]], device='cuda:0')
tensor([[-2.1872,  1.7512]], device='cuda:0')
tensor([[ 2.7597, -2.4551]], device='cuda:0')
tensor([[ 3.0563, -2.6988]], device='cuda:0')
tensor([[ 2.8048, -2.4370]], device='cuda:0')
tensor([[ 3.0362, -2.7389]], device='cuda:0')
tensor([[ 2.6525, -2.3644]], device='cuda:0')
tensor([[ 3.0527, -2.7475]], device='cuda:0')
tensor([[ 3.0126, -2.7281]], device='cuda:0')
tensor([[ 2.9648, -2.6728]], device='cuda:0')
tensor([[ 2.9728, -2.6542]], device='cuda:0')
tensor([[ 3.0088, -2.6879]], devic

tensor([[ 3.0017, -2.7734]], device='cuda:0')
tensor([[ 2.6894, -2.2833]], device='cuda:0')
tensor([[ 2.9661, -2.7643]], device='cuda:0')
tensor([[ 3.0032, -2.7322]], device='cuda:0')
tensor([[ 3.0623, -2.6686]], device='cuda:0')
tensor([[ 2.9170, -2.6504]], device='cuda:0')
tensor([[ 3.0883, -2.7402]], device='cuda:0')
tensor([[ 1.0112, -0.4678]], device='cuda:0')
tensor([[ 2.9311, -2.4531]], device='cuda:0')
tensor([[-2.2580,  1.9846]], device='cuda:0')
tensor([[ 3.0542, -2.7272]], device='cuda:0')
tensor([[ 3.0280, -2.7217]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 3.0782, -2.7103]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 1.3546, -1.0079]], device='cuda:0')
tensor([[ 2.9659, -2.6704]], device='cuda:0')
tensor([[ 3.0159, -2.5689]], device='cuda:0')
tensor([[ 3.0418, -2.7297]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.8868, -2.2626]], device='cuda:0')
tensor([[ 2.9976, -2.6969]], devic

tensor([[-2.2572,  1.9567]], device='cuda:0')
tensor([[ 3.0321, -2.7913]], device='cuda:0')
tensor([[ 2.9761, -2.7425]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.9348, -2.7138]], device='cuda:0')
tensor([[ 2.6292, -2.2916]], device='cuda:0')
tensor([[ 2.6351, -2.2502]], device='cuda:0')
tensor([[-2.1470,  1.9080]], device='cuda:0')
tensor([[-2.1952,  1.9394]], device='cuda:0')
tensor([[ 2.6756, -2.4380]], device='cuda:0')
tensor([[ 2.9397, -2.6023]], device='cuda:0')
tensor([[ 2.9026, -2.5535]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 3.0107, -2.7217]], device='cuda:0')
tensor([[ 2.6714, -2.1276]], device='cuda:0')
tensor([[ 2.6401, -2.4508]], device='cuda:0')
tensor([[ 2.9738, -2.6880]], device='cuda:0')
tensor([[-2.2013,  1.8131]], device='cuda:0')
tensor([[ 3.0209, -2.7351]], device='cuda:0')
tensor([[ 3.0528, -2.7234]], device='cuda:0')
tensor([[ 2.9365, -2.6705]], device='cuda:0')
tensor([[ 2.7475, -2.4376]], devic

tensor([[ 2.8699, -2.2565]], device='cuda:0')
tensor([[ 3.0292, -2.7518]], device='cuda:0')
tensor([[ 2.9984, -2.7796]], device='cuda:0')
tensor([[ 2.7739, -2.2698]], device='cuda:0')
tensor([[ 2.6217, -2.1411]], device='cuda:0')
tensor([[ 2.9465, -2.7150]], device='cuda:0')
tensor([[ 0.8870, -0.5508]], device='cuda:0')
tensor([[ 2.7704, -2.1662]], device='cuda:0')
tensor([[ 2.9450, -2.7195]], device='cuda:0')
tensor([[ 2.5978, -2.3758]], device='cuda:0')
tensor([[-2.2268,  2.0322]], device='cuda:0')
tensor([[ 2.7087, -2.2475]], device='cuda:0')
tensor([[ 2.6409, -2.2573]], device='cuda:0')
tensor([[ 0.4423, -0.3000]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[ 2.7138, -2.2199]], device='cuda:0')
tensor([[ 2.9147, -2.6085]], device='cuda:0')
tensor([[ 2.8799, -2.6074]], device='cuda:0')
tensor([[-2.1421,  1.8632]], device='cuda:0')
tensor([[ 3.0000, -2.7209]], device='cuda:0')
tensor([[ 3.0649, -2.7242]], device='cuda:0')
tensor([[-2.1757,  1.9430]], devic

tensor([[ 2.7708, -2.4572]], device='cuda:0')
tensor([[ 2.9689, -2.7027]], device='cuda:0')
tensor([[ 1.8141, -1.2841]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[-2.1772,  1.7428]], device='cuda:0')
tensor([[ 2.5892, -2.1051]], device='cuda:0')
tensor([[ 2.8838, -2.6422]], device='cuda:0')
tensor([[-2.1646,  1.8868]], device='cuda:0')
tensor([[ 0.7008, -0.3928]], device='cuda:0')
tensor([[ 2.6408, -2.2679]], device='cuda:0')
tensor([[ 2.4255, -1.7270]], device='cuda:0')
tensor([[ 2.9967, -2.6345]], device='cuda:0')
tensor([[ 2.7552, -2.3914]], device='cuda:0')
tensor([[ 2.9779, -2.5787]], device='cuda:0')
tensor([[-2.2385,  1.8037]], device='cuda:0')
tensor([[ 2.8020, -2.4667]], device='cuda:0')
tensor([[ 2.5786, -2.2140]], device='cuda:0')
tensor([[ 2.9215, -2.5382]], device='cuda:0')
tensor([[ 3.0440, -2.7570]], device='cuda:0')
tensor([[ 2.7597, -2.4551]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 3.0159, -2.5689]], devic

tensor([[ 2.9700, -2.6114]], device='cuda:0')
tensor([[ 2.7991, -2.2588]], device='cuda:0')
tensor([[ 2.8582, -2.3620]], device='cuda:0')
tensor([[-2.2944,  1.8287]], device='cuda:0')
tensor([[ 2.6401, -2.4508]], device='cuda:0')
tensor([[ 3.0665, -2.7364]], device='cuda:0')
tensor([[ 3.0665, -2.7364]], device='cuda:0')
tensor([[ 3.0492, -2.7147]], device='cuda:0')
tensor([[ 2.7651, -2.0399]], device='cuda:0')
tensor([[ 2.8677, -2.5911]], device='cuda:0')
tensor([[ 3.0112, -2.6862]], device='cuda:0')
tensor([[ 2.8521, -2.6672]], device='cuda:0')
tensor([[ 2.7552, -2.3914]], device='cuda:0')
tensor([[ 2.7671, -2.2563]], device='cuda:0')
tensor([[ 2.6540, -2.0036]], device='cuda:0')
tensor([[ 2.7708, -2.4477]], device='cuda:0')
tensor([[ 2.9882, -2.7415]], device='cuda:0')
tensor([[ 3.0268, -2.7147]], device='cuda:0')
tensor([[ 2.7576, -2.5096]], device='cuda:0')
tensor([[ 2.7529, -2.0033]], device='cuda:0')
tensor([[ 2.9896, -2.6582]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], devic

tensor([[ 2.8653, -2.5954]], device='cuda:0')
tensor([[-2.2055,  1.8976]], device='cuda:0')
tensor([[-2.2153,  1.9626]], device='cuda:0')
tensor([[ 2.7478, -2.3423]], device='cuda:0')
tensor([[-2.1952,  1.9815]], device='cuda:0')
tensor([[ 3.0425, -2.7200]], device='cuda:0')
tensor([[ 1.8345, -1.4335]], device='cuda:0')
tensor([[ 2.9882, -2.7415]], device='cuda:0')
tensor([[ 2.9849, -2.6302]], device='cuda:0')
tensor([[ 3.0540, -2.7621]], device='cuda:0')
tensor([[ 2.9711, -2.7624]], device='cuda:0')
tensor([[ 2.9955, -2.7088]], device='cuda:0')
tensor([[ 2.5546, -2.1013]], device='cuda:0')
tensor([[ 2.9181, -2.7101]], device='cuda:0')
tensor([[ 2.8184, -2.5540]], device='cuda:0')
tensor([[ 2.4469, -1.6783]], device='cuda:0')
tensor([[-2.2248,  1.9836]], device='cuda:0')
tensor([[ 3.0446, -2.7906]], device='cuda:0')
tensor([[ 2.6401, -2.4508]], device='cuda:0')
tensor([[ 3.0275, -2.7379]], device='cuda:0')
tensor([[ 3.0464, -2.7057]], device='cuda:0')
tensor([[-2.2015,  1.8599]], devic

tensor([[ 3.0159, -2.5689]], device='cuda:0')
tensor([[ 3.0787, -2.7243]], device='cuda:0')
tensor([[ 2.9967, -2.6345]], device='cuda:0')
tensor([[ 2.9365, -2.6115]], device='cuda:0')
tensor([[ 2.6350, -2.2452]], device='cuda:0')
tensor([[ 3.0925, -2.7099]], device='cuda:0')
tensor([[ 2.7046, -2.1922]], device='cuda:0')
tensor([[ 3.0665, -2.7364]], device='cuda:0')
tensor([[ 2.9849, -2.6302]], device='cuda:0')
tensor([[ 2.8916, -2.6879]], device='cuda:0')
tensor([[ 2.8162, -2.5941]], device='cuda:0')
tensor([[ 2.7837, -2.2349]], device='cuda:0')
tensor([[-2.2819,  1.9476]], device='cuda:0')
tensor([[-2.0923,  1.6562]], device='cuda:0')
tensor([[ 3.0093, -2.6819]], device='cuda:0')
tensor([[ 2.9692, -2.6709]], device='cuda:0')
tensor([[ 2.9828, -2.7592]], device='cuda:0')
tensor([[ 2.7812, -2.4317]], device='cuda:0')
tensor([[ 2.7087, -2.2475]], device='cuda:0')
tensor([[-2.1835,  2.0027]], device='cuda:0')
tensor([[ 3.0302, -2.7348]], device='cuda:0')
tensor([[ 3.0383, -2.6876]], devic

tensor([[ 2.7224, -2.4412]], device='cuda:0')
tensor([[ 2.9797, -2.6631]], device='cuda:0')
tensor([[ 2.5979, -2.3179]], device='cuda:0')
tensor([[ 2.7601, -2.4883]], device='cuda:0')
tensor([[ 2.8184, -2.5540]], device='cuda:0')
tensor([[ 2.8998, -2.6465]], device='cuda:0')
tensor([[-2.1593,  1.6387]], device='cuda:0')
tensor([[ 2.8798, -2.4783]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[ 3.0126, -2.7281]], device='cuda:0')
tensor([[ 2.8297, -2.5776]], device='cuda:0')
tensor([[ 2.6651, -2.1272]], device='cuda:0')
tensor([[ 2.8772, -2.6171]], device='cuda:0')
tensor([[ 2.6734, -2.3307]], device='cuda:0')
tensor([[ 3.0618, -2.7040]], device='cuda:0')
tensor([[ 3.0563, -2.6988]], device='cuda:0')
tensor([[-2.2028,  1.9367]], device='cuda:0')
tensor([[ 2.8211, -2.5329]], device='cuda:0')
tensor([[ 3.0504, -2.7039]], device='cuda:0')
tensor([[ 2.8297, -2.5776]], device='cuda:0')
tensor([[-2.2001,  1.9797]], device='cuda:0')
tensor([[ 2.9385, -2.6800]], devic

tensor([[ 3.0094, -2.7477]], device='cuda:0')
tensor([[-2.3073,  1.9742]], device='cuda:0')
tensor([[-2.1553,  1.8292]], device='cuda:0')
tensor([[-2.2450,  1.9472]], device='cuda:0')
tensor([[-2.2499,  1.9924]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[ 3.0779, -2.7099]], device='cuda:0')
tensor([[ 2.9849, -2.6302]], device='cuda:0')
tensor([[-2.2627,  1.9553]], device='cuda:0')
tensor([[ 2.9615, -2.7138]], device='cuda:0')
tensor([[ 2.3285, -2.1069]], device='cuda:0')
tensor([[ 3.0051, -2.7167]], device='cuda:0')
tensor([[ 2.5645, -1.8712]], device='cuda:0')
tensor([[ 1.1004, -0.8975]], device='cuda:0')
tensor([[-2.2148,  1.9686]], device='cuda:0')
tensor([[ 2.9986, -2.6967]], device='cuda:0')
tensor([[ 2.6837, -2.4312]], device='cuda:0')
tensor([[ 2.9397, -2.6023]], device='cuda:0')
tensor([[ 2.8012, -2.4526]], device='cuda:0')
tensor([[ 3.0159, -2.5689]], device='cuda:0')
tensor([[ 2.8521, -2.6672]], device='cuda:0')
tensor([[ 2.9197, -2.5708]], devic

tensor([[ 2.9622, -2.5964]], device='cuda:0')
tensor([[ 3.0190, -2.7131]], device='cuda:0')
tensor([[ 2.5786, -2.2140]], device='cuda:0')
tensor([[ 2.7601, -2.4883]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 3.0185, -2.6651]], device='cuda:0')
tensor([[ 2.6258, -2.2433]], device='cuda:0')
tensor([[ 3.0774, -2.5476]], device='cuda:0')
tensor([[ 3.0010, -2.7136]], device='cuda:0')
tensor([[ 2.9812, -2.7209]], device='cuda:0')
tensor([[ 2.7708, -2.4477]], device='cuda:0')
tensor([[-2.2301,  1.9992]], device='cuda:0')
tensor([[ 1.7807, -1.2395]], device='cuda:0')
tensor([[ 3.0563, -2.6988]], device='cuda:0')
tensor([[ 3.0474, -2.5493]], device='cuda:0')
tensor([[-2.1968,  1.8959]], device='cuda:0')
tensor([[ 2.9761, -2.7425]], device='cuda:0')
tensor([[ 3.0197, -2.6896]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.8687, -2.6297]], device='cuda:0')
tensor([[ 2.9967, -2.6345]], device='cuda:0')
tensor([[ 2.8524, -2.6313]], devic

tensor([[ 2.7813, -2.5410]], device='cuda:0')
tensor([[ 2.9416, -2.6674]], device='cuda:0')
tensor([[ 2.9385, -2.6800]], device='cuda:0')
tensor([[-2.1658,  1.8524]], device='cuda:0')
tensor([[ 3.0665, -2.7364]], device='cuda:0')
tensor([[ 2.8916, -2.6879]], device='cuda:0')
tensor([[ 2.5978, -2.3758]], device='cuda:0')
tensor([[ 2.9604, -2.7631]], device='cuda:0')
tensor([[ 2.9162, -2.6442]], device='cuda:0')
tensor([[ 2.5915, -2.1859]], device='cuda:0')
tensor([[ 2.9016, -2.5926]], device='cuda:0')
tensor([[ 2.9881, -2.7563]], device='cuda:0')
tensor([[ 3.0403, -2.7201]], device='cuda:0')
tensor([[-2.1685,  1.9024]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.6401, -2.4508]], device='cuda:0')
tensor([[-2.1620,  1.8546]], device='cuda:0')
tensor([[ 2.8414, -2.5982]], device='cuda:0')
tensor([[ 2.8784, -2.3498]], device='cuda:0')
tensor([[-2.2528,  1.9844]], device='cuda:0')
tensor([[-2.2702,  1.8492]], device='cuda:0')
tensor([[ 3.0563, -2.6988]], devic

tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[-0.1816, -0.3041]], device='cuda:0')
tensor([[-2.2359,  1.9306]], device='cuda:0')
tensor([[-1.7162,  1.3738]], device='cuda:0')
tensor([[-2.2679,  2.0048]], device='cuda:0')
tensor([[ 2.9692, -2.6709]], device='cuda:0')
tensor([[ 2.5943, -2.3641]], device='cuda:0')
tensor([[ 2.8696, -2.4366]], device='cuda:0')
tensor([[ 3.0268, -2.7147]], device='cuda:0')
tensor([[ 3.0392, -2.7130]], device='cuda:0')
tensor([[-2.1883,  1.9216]], device='cuda:0')
tensor([[ 3.0197, -2.6896]], device='cuda:0')
tensor([[-2.2152,  1.9502]], device='cuda:0')
tensor([[ 2.9458, -2.5892]], device='cuda:0')
tensor([[ 3.0280, -2.7217]], device='cuda:0')
tensor([[ 2.8615, -2.5874]], device='cuda:0')
tensor([[-2.1893,  1.8684]], device='cuda:0')
tensor([[ 2.6756, -2.4380]], device='cuda:0')
tensor([[ 2.8444, -2.5382]], device='cuda:0')
tensor([[ 2.6734, -2.3307]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 1.4317, -0.8679]], devic

tensor([[ 2.7552, -2.3914]], device='cuda:0')
tensor([[ 3.0141, -2.7526]], device='cuda:0')
tensor([[ 3.0618, -2.7040]], device='cuda:0')
tensor([[ 2.9311, -2.4531]], device='cuda:0')
tensor([[ 2.5242, -2.1756]], device='cuda:0')
tensor([[-2.2529,  1.6771]], device='cuda:0')
tensor([[ 2.7884, -2.3993]], device='cuda:0')
tensor([[-2.2467,  1.6466]], device='cuda:0')
tensor([[ 2.7651, -2.0399]], device='cuda:0')
tensor([[ 2.5945, -2.1603]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.7476, -2.2858]], device='cuda:0')
tensor([[ 3.0353, -2.7212]], device='cuda:0')
tensor([[-2.2162,  1.9522]], device='cuda:0')
tensor([[ 2.7552, -2.3914]], device='cuda:0')
tensor([[ 2.6734, -2.3307]], device='cuda:0')
tensor([[ 2.9214, -2.5855]], device='cuda:0')
tensor([[ 2.7357, -2.4646]], device='cuda:0')
tensor([[ 2.9692, -2.6709]], device='cuda:0')
tensor([[ 3.0474, -2.5493]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.6653, -2.0854]], devic

tensor([[ 2.5187, -2.0321]], device='cuda:0')
tensor([[-2.3017,  1.9692]], device='cuda:0')
tensor([[ 2.5239, -2.0796]], device='cuda:0')
tensor([[ 2.9801, -2.6589]], device='cuda:0')
tensor([[ 2.7552, -2.3914]], device='cuda:0')
tensor([[ 3.0440, -2.7328]], device='cuda:0')
tensor([[ 2.9311, -2.4531]], device='cuda:0')
tensor([[-2.2350,  1.9809]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[-2.1201,  1.8015]], device='cuda:0')
tensor([[ 3.0328, -2.7042]], device='cuda:0')
tensor([[ 2.8533, -2.3480]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.7708, -2.4477]], device='cuda:0')
tensor([[ 2.9385, -2.6800]], device='cuda:0')
tensor([[ 2.9055, -2.5498]], device='cuda:0')
tensor([[ 2.9780, -2.6954]], device='cuda:0')
tensor([[ 2.9967, -2.6345]], device='cuda:0')
tensor([[ 2.3674, -1.4714]], device='cuda:0')
tensor([[ 2.9922, -2.7168]], device='cuda:0')
tensor([[ 2.9711, -2.6463]], device='cuda:0')
tensor([[ 2.9293, -2.5784]], devic

tensor([[ 2.9990, -2.5802]], device='cuda:0')
tensor([[ 2.7737, -2.1608]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 3.0996, -2.7275]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[-2.1584,  1.9090]], device='cuda:0')
tensor([[ 3.0372, -2.7124]], device='cuda:0')
tensor([[ 3.0451, -2.6913]], device='cuda:0')
tensor([[ 2.6168, -2.3482]], device='cuda:0')
tensor([[ 2.9455, -2.5207]], device='cuda:0')
tensor([[ 2.9353, -2.6523]], device='cuda:0')
tensor([[ 2.9605, -2.6542]], device='cuda:0')
tensor([[-2.2263,  1.8690]], device='cuda:0')
tensor([[ 2.4900, -2.2105]], device='cuda:0')
tensor([[ 2.6845, -2.1865]], device='cuda:0')
tensor([[ 2.1590, -1.5566]], device='cuda:0')
tensor([[ 3.0615, -2.7079]], device='cuda:0')
tensor([[ 2.7812, -2.4317]], device='cuda:0')
tensor([[ 2.9922, -2.7168]], device='cuda:0')
tensor([[ 3.0109, -2.6035]], device='cuda:0')
tensor([[ 3.0714, -2.7515]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], devic

tensor([[ 2.5801, -2.0516]], device='cuda:0')
tensor([[ 2.6076, -2.3331]], device='cuda:0')
tensor([[-2.1414,  1.9197]], device='cuda:0')
tensor([[ 3.0696, -2.5900]], device='cuda:0')
tensor([[ 2.9502, -2.7263]], device='cuda:0')
tensor([[ 2.7721, -2.4903]], device='cuda:0')
tensor([[ 3.0009, -2.6904]], device='cuda:0')
tensor([[ 2.5732, -2.1342]], device='cuda:0')
tensor([[-2.1652,  1.9145]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.7406, -2.2676]], device='cuda:0')
tensor([[-2.1643,  1.5647]], device='cuda:0')
tensor([[ 3.0522, -2.6767]], device='cuda:0')
tensor([[ 2.5924, -2.3678]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[-2.2027,  1.8647]], device='cuda:0')
tensor([[ 0.4549, -0.4818]], device='cuda:0')
tensor([[ 2.7990, -2.1657]], device='cuda:0')
tensor([[-2.1553,  1.8292]], device='cuda:0')
tensor([[ 2.8772, -2.6171]], device='cuda:0')
tensor([[-2.1294,  1.6149]], device='cuda:0')
tensor([[ 2.9831, -2.6419]], devic

tensor([[ 2.9468, -2.7210]], device='cuda:0')
tensor([[-1.3229,  0.9699]], device='cuda:0')
tensor([[ 2.9662, -2.5976]], device='cuda:0')
tensor([[ 2.7080, -2.1873]], device='cuda:0')
tensor([[ 3.0618, -2.7040]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.5579, -2.0315]], device='cuda:0')
tensor([[ 1.7142, -1.3388]], device='cuda:0')
tensor([[ 2.5303, -2.2945]], device='cuda:0')
tensor([[ 2.9969, -2.5726]], device='cuda:0')
tensor([[ 2.6477, -2.2051]], device='cuda:0')
tensor([[ 2.6734, -2.3307]], device='cuda:0')
tensor([[ 2.9736, -2.7308]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[-2.1801,  1.9555]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 3.0474, -2.5493]], device='cuda:0')
tensor([[ 3.0000, -2.7209]], device='cuda:0')
tensor([[ 2.9147, -2.6085]], device='cuda:0')
tensor([[-2.2588,  1.9666]], device='cuda:0')
tensor([[ 2.6842, -2.4571]], device='cuda:0')
tensor([[ 3.0217, -2.6603]], devic

tensor([[ 2.8341, -2.5243]], device='cuda:0')
tensor([[ 3.0163, -2.5325]], device='cuda:0')
tensor([[ 3.0368, -2.7278]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 3.0004, -2.6625]], device='cuda:0')
tensor([[ 2.7504, -2.3991]], device='cuda:0')
tensor([[ 2.6631, -2.3969]], device='cuda:0')
tensor([[ 2.9955, -2.7088]], device='cuda:0')
tensor([[ 2.9890, -2.7386]], device='cuda:0')
tensor([[-2.1258,  1.5963]], device='cuda:0')
tensor([[ 2.7433, -2.5417]], device='cuda:0')
tensor([[ 2.9026, -2.5535]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[ 2.9912, -2.7522]], device='cuda:0')
tensor([[-2.2698,  2.0308]], device='cuda:0')
tensor([[ 2.6526, -1.9919]], device='cuda:0')
tensor([[ 3.0015, -2.6915]], device='cuda:0')
tensor([[-2.1577,  1.5506]], device='cuda:0')
tensor([[-2.2123,  1.9073]], device='cuda:0')
tensor([[ 2.9731, -2.5871]], device='cuda:0')
tensor([[ 2.9237, -2.5107]], device='cuda:0')
tensor([[ 3.0447, -2.6782]], devic

tensor([[ 3.0073, -2.6844]], device='cuda:0')
tensor([[ 2.5264, -2.3478]], device='cuda:0')
tensor([[ 2.9728, -2.6542]], device='cuda:0')
tensor([[-2.2745,  2.0171]], device='cuda:0')
tensor([[ 2.9935, -2.7252]], device='cuda:0')
tensor([[ 2.8913, -2.5839]], device='cuda:0')
tensor([[ 2.5291, -2.0895]], device='cuda:0')
tensor([[ 2.3214, -1.6276]], device='cuda:0')
tensor([[-2.1812,  1.9574]], device='cuda:0')
tensor([[ 2.9692, -2.6709]], device='cuda:0')
tensor([[ 2.6397, -2.3108]], device='cuda:0')
tensor([[ 2.8228, -2.3619]], device='cuda:0')
tensor([[ 2.8093, -2.6267]], device='cuda:0')
tensor([[-2.0747,  1.5854]], device='cuda:0')
tensor([[-2.2604,  2.0132]], device='cuda:0')
tensor([[-2.1823,  1.8673]], device='cuda:0')
tensor([[ 2.9193, -2.5933]], device='cuda:0')
tensor([[-2.1938,  1.9546]], device='cuda:0')
tensor([[ 3.0383, -2.7111]], device='cuda:0')
tensor([[ 2.7104, -2.1184]], device='cuda:0')
tensor([[ 2.9465, -2.7243]], device='cuda:0')
tensor([[ 2.8299, -2.6636]], devic

tensor([[ 2.9692, -2.6709]], device='cuda:0')
tensor([[-2.2445,  1.7487]], device='cuda:0')
tensor([[ 0.7796, -0.5100]], device='cuda:0')
tensor([[ 2.9647, -2.7021]], device='cuda:0')
tensor([[ 2.3339, -1.7667]], device='cuda:0')
tensor([[ 2.9967, -2.6345]], device='cuda:0')
tensor([[ 0.9181, -0.7169]], device='cuda:0')
tensor([[ 2.8916, -2.6879]], device='cuda:0')
tensor([[ 2.9455, -2.5207]], device='cuda:0')
tensor([[ 2.8051, -2.6335]], device='cuda:0')
tensor([[ 2.7552, -2.3914]], device='cuda:0')
tensor([[ 3.0489, -2.7225]], device='cuda:0')
tensor([[ 3.0203, -2.7173]], device='cuda:0')
tensor([[ 3.0408, -2.7238]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[ 3.0170, -2.6679]], device='cuda:0')
tensor([[-2.0955,  1.8051]], device='cuda:0')
tensor([[ 3.0984, -2.7065]], device='cuda:0')
tensor([[ 2.9961, -2.6345]], device='cuda:0')
tensor([[ 3.0131, -2.6481]], device='cuda:0')
tensor([[ 2.9311, -2.4531]], device='cuda:0')
tensor([[ 2.6651, -2.1272]], devic

tensor([[ 2.8429, -2.4846]], device='cuda:0')
tensor([[-2.1188,  1.4626]], device='cuda:0')
tensor([[ 2.8429, -2.4846]], device='cuda:0')
tensor([[-2.2096,  1.9328]], device='cuda:0')
tensor([[-2.0743,  1.7070]], device='cuda:0')
tensor([[-2.2074,  1.9561]], device='cuda:0')
tensor([[-1.9483,  1.4093]], device='cuda:0')
tensor([[ 2.7632, -2.4524]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.9716, -2.6370]], device='cuda:0')
tensor([[ 3.0188, -2.6826]], device='cuda:0')
tensor([[ 2.9372, -2.7112]], device='cuda:0')
tensor([[ 2.7293, -1.9851]], device='cuda:0')
tensor([[ 3.0453, -2.7308]], device='cuda:0')
tensor([[ 3.0758, -2.6660]], device='cuda:0')
tensor([[ 2.6104, -2.0568]], device='cuda:0')
tensor([[ 3.0215, -2.7093]], device='cuda:0')
tensor([[-2.3146,  1.8559]], device='cuda:0')
tensor([[ 2.8031, -2.6132]], device='cuda:0')
tensor([[ 2.9288, -2.6605]], device='cuda:0')
tensor([[ 2.9830, -2.6804]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], devic

tensor([[ 3.0476, -2.7428]], device='cuda:0')
tensor([[ 2.9026, -2.5535]], device='cuda:0')
tensor([[ 3.0631, -2.6767]], device='cuda:0')
tensor([[ 2.9134, -2.5470]], device='cuda:0')
tensor([[ 2.8131, -2.4803]], device='cuda:0')
tensor([[ 2.6735, -2.4583]], device='cuda:0')
tensor([[ 2.9692, -2.6709]], device='cuda:0')
tensor([[ 3.0209, -2.7351]], device='cuda:0')
tensor([[-2.1861,  1.9353]], device='cuda:0')
tensor([[ 2.9180, -2.6867]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.7475, -2.5572]], device='cuda:0')
tensor([[ 3.0205, -2.7314]], device='cuda:0')
tensor([[ 2.7890, -2.5044]], device='cuda:0')
tensor([[ 2.8595, -2.5818]], device='cuda:0')
tensor([[ 2.6201, -2.2808]], device='cuda:0')
tensor([[ 2.9416, -2.6674]], device='cuda:0')
tensor([[ 3.0142, -2.7013]], device='cuda:0')
tensor([[ 3.0201, -2.7035]], device='cuda:0')
tensor([[-2.2731,  1.8816]], device='cuda:0')
tensor([[-0.0862, -0.1439]], device='cuda:0')
tensor([[ 2.8651, -2.5854]], devic

tensor([[ 2.6118, -2.2171]], device='cuda:0')
tensor([[ 2.5625, -1.7838]], device='cuda:0')
tensor([[-2.2435,  1.8993]], device='cuda:0')
tensor([[ 2.9428, -2.6255]], device='cuda:0')
tensor([[ 2.9783, -2.5576]], device='cuda:0')
tensor([[ 2.9336, -2.3072]], device='cuda:0')
tensor([[ 2.7828, -2.4927]], device='cuda:0')
tensor([[-1.9160,  1.5027]], device='cuda:0')
tensor([[ 3.0422, -2.7685]], device='cuda:0')
tensor([[ 2.7344, -2.2938]], device='cuda:0')
tensor([[ 2.9731, -2.5871]], device='cuda:0')
tensor([[ 2.7344, -2.2938]], device='cuda:0')
tensor([[ 2.7087, -2.2475]], device='cuda:0')
tensor([[ 1.6538, -1.3599]], device='cuda:0')
tensor([[ 3.0309, -2.7142]], device='cuda:0')
tensor([[ 2.7632, -2.4524]], device='cuda:0')
tensor([[ 3.0667, -2.7331]], device='cuda:0')
tensor([[ 2.4727, -1.8085]], device='cuda:0')
tensor([[ 2.9471, -2.5698]], device='cuda:0')
tensor([[ 2.9183, -2.6251]], device='cuda:0')
tensor([[-1.7162,  1.3738]], device='cuda:0')
tensor([[ 2.7467, -2.1236]], devic

text_list

In [64]:
test_result

[['[101, 118, 8211, 0, 0, 0, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 2544, 5131, 120, 4229, 120, 4397, 4403, 1048, 6527, 0, 102]', 1],
 ['[101, 769, 3211, 3209, 5169, 0, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 1912, 2380, 0, 0, 0, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 1366, 0, 0, 0, 0, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 2875, 4277, 3862, 7804, 5114, 0, 0, 0, 0, 0, 102]', 1],
 ['[101, 5429, 5428, 2831, 7619, 0, 0, 0, 0, 0, 0, 102]', 1],
 ['[101, 8439, 120, 8111, 120, 8110, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 802, 4412, 131, 0, 0, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 1066, 123, 7517, 1912, 2380, 131, 0, 0, 0, 0, 102]', 0],
 ['[101, 8150, 118, 8211, 0, 0, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 109, 8290, 0, 0, 0, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 8113, 131, 8259, 0, 0, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 4634, 4873, 5998, 131, 100, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 7032, 7540, 0, 0, 0, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 102]', 0],
 ['[101, 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 102]

In [26]:
test_data = pd.read_csv('/home/jacklee/ocr/food_classification/train_data/test_data.csv')

In [27]:
test_data

Unnamed: 0,text,label
0,-32,0
1,微糖/熱/珍珠免費,0
2,交易明細,0
3,外帶,0
4,口,0
...,...,...
5394,1460.0,0
5395,零:,0
5396,交易明細,0
5397,澳洲洋葱(大),1


In [65]:
test_data['pre'] = 0
for i in range(len(test_dataset)):
    test_data['pre'][i] = test_result[i][1]

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until


In [70]:
a=0
for i in range(len(test_dataset)):   
    if test_data['label'][i] != test_data['pre'][i]:
        a+=1