In [1]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertForSequenceClassification

from tqdm import tqdm

# BERT模型分类

In [2]:
# Load the BERT tokenizer from a local bert-base-chinese checkpoint directory.
# The path is relative, so it resolves against the notebook's working directory.
bert_tokenizer = BertTokenizer.from_pretrained(
    '../../H/models/huggingface/bert-base-chinese')

## 创建数据集

In [3]:
# Relative paths to the THUCNews train / dev / test split files.
# load_corpus() below expects one "<sentence><TAB><integer label>" record per line.
train_path = "../../H/datasets/THUCNews/train.txt"
dev_path = "../../H/datasets/THUCNews/dev.txt"
test_path = "../../H/datasets/THUCNews/test.txt"

In [28]:
# Load the corpus
def load_corpus(path):
    """Read a tab-separated corpus file into parallel sentence / label lists.

    Every non-blank line must have the form ``<sentence><TAB><integer label>``.
    Blank lines are skipped. Progress is reported through ``tqdm``.

    Args:
        path: Path of a UTF-8 encoded text file.

    Returns:
        A ``(sentences, labels)`` tuple: a list of str and a list of int of
        equal length, in file order.

    Raises:
        ValueError: If a non-blank line has no tab separator or its label is
            not an integer.
    """
    sentences = []
    labels = []
    with open(path, 'r', encoding='UTF-8') as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                continue
            # Split on the LAST tab only: the label is always the final field,
            # so a sentence that itself contains a tab no longer breaks unpacking.
            sent, label = line.rsplit('\t', 1)
            sentences.append(sent)
            labels.append(int(label))
    return sentences, labels

# Parse the training split into parallel lists: sentences (str) and labels (int).
train_data, train_labels = load_corpus(train_path)

180000it [00:00, 1428599.28it/s]


In [29]:
# Sanity check: there must be exactly one label per sentence.
len(train_data),len(train_labels)

(180000, 180000)

In [30]:
# Number of distinct label ids in the training split.
len(set(train_labels))

10

In [6]:
# Peek at one (sentence, label) pair.
train_data[900], train_labels[900]

('斯里兰卡急派外交部长前往巴基斯坦', 6)

In [7]:
# Vectorize
def vectorize(sentences, tokenizer):
    """Encode every sentence into a list of token ids.

    Args:
        sentences: Iterable of raw text strings.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``.

    Returns:
        A list with one encoded id sequence per input sentence, in input
        order. Sequences are not padded, so their lengths may differ.
    """
    return [
        tokenizer.encode(text, add_special_tokens=True)  # ask for special tokens too
        for text in sentences
    ]

# Encode the whole training split, then display one encoded example.
input_ids = vectorize(train_data, tokenizer=bert_tokenizer)
input_ids[900]

[101,
 3172,
 7027,
 1065,
 1305,
 2593,
 3836,
 1912,
 769,
 6956,
 7270,
 1184,
 2518,
 2349,
 1825,
 3172,
 1788,
 102]

In [8]:
# Length of the longest encoded sequence; used below as the padding width.
MAX_LEN = max(map(len, input_ids))
print("Max sentence length: ", MAX_LEN)

Max sentence length:  35


In [9]:
# Pad every sequence to the same length
from tensorflow.keras.preprocessing.sequence import pad_sequences

print("Padding token: {:}, ID: {:}".format(bert_tokenizer.pad_token,
                                           bert_tokenizer.pad_token_id))
# Right-pad (padding='post') with the hard-coded value 0 up to MAX_LEN.
# NOTE: this rebinds `input_ids` from the list of id lists to the padded result.
input_ids = pad_sequences(input_ids,
                          maxlen=MAX_LEN,
                          dtype='long',
                          value=0,
                          truncating="post",
                          padding='post')
input_ids.shape,input_ids[900]

Padding token: [PAD], ID: 0


((180000, 35),
 array([ 101, 3172, 7027, 1065, 1305, 2593, 3836, 1912,  769, 6956, 7270,
        1184, 2518, 2349, 1825, 3172, 1788,  102,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0]))

In [10]:
# Attention mask that goes with the padded ids:
# 1 marks a real token, 0 marks padding.

def create_mask(input_ids):
    """Build an attention mask for each row of token ids.

    Args:
        input_ids: Iterable of id sequences (padding is expected to be id 0).

    Returns:
        A list of lists of int with the same layout as ``input_ids``; an
        entry is 1 where the id is greater than 0 and 0 otherwise.
    """
    return [[int(token_id > 0) for token_id in row] for row in input_ids]

# Build the attention masks for the padded ids and print one as a spot check.
mask = create_mask(input_ids)
print(mask[900])

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [11]:
# Convert to PyTorch tensors
import torch

train_inputs = torch.tensor(input_ids)
# NOTE: rebinds `train_labels` from the Python list returned by load_corpus to a tensor.
train_labels = torch.tensor(train_labels)
train_mask = torch.tensor(mask)
# Display the three shapes; their first dimensions must agree.
train_inputs.shape, train_labels.shape, train_mask.shape

(torch.Size([180000, 35]), torch.Size([180000]), torch.Size([180000, 35]))

In [13]:
# Build the data pipeline

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

batch_size = 32
# NOTE: rebinds `train_data` from the list of raw sentences to a TensorDataset,
# so cells above that expect the sentence list cannot be re-run after this one.
train_data = TensorDataset(train_inputs, train_mask, train_labels)
# RandomSampler: examples are drawn in random order.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data,
                              sampler=train_sampler,
                              batch_size=batch_size)
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f644bf5e290>

In [14]:
# Number of batches the loader yields per pass over the data.
len(train_dataloader)

5625

In [24]:
# 将整个过程整合起来


class DataGen:
    """Turn a tab-separated corpus file into padded tensors and DataLoaders.

    Construction runs the whole preprocessing pipeline once
    (read file -> token ids -> right-pad with ``PAD_ID`` to the longest
    sequence -> attention mask) and stores the results as tensors. Calling
    the instance wraps those tensors in a ``DataLoader``.

    Attributes:
        tokenizer: The tokenizer passed to the constructor.
        train_inputs: Long tensor of padded token ids, one row per sentence.
        train_mask: Long tensor of the same shape; 1 for real tokens, 0 for padding.
        train_labels: Long tensor with one integer label per sentence.
    """

    # Id used to right-pad sequences. create_mask() treats ids that are not
    # greater than 0 as padding, so the two must stay in agreement.
    PAD_ID = 0

    def __init__(self, path, tokenizer):
        """Load ``path`` and precompute the input, mask and label tensors.

        Args:
            path: Corpus file with one ``<sentence><TAB><integer label>`` per line.
            tokenizer: Object exposing ``encode(text, add_special_tokens=...)``.

        Raises:
            ValueError: If the file holds no labelled sentence, or a line is malformed.
        """
        self.tokenizer = tokenizer

        sentences, labels = self.load_corpus(path)
        if not sentences:
            # Fail with a clear message instead of an opaque error from padding.
            raise ValueError("no labelled sentences found in {!r}".format(path))

        encoded = self.vectorize(sentences, self.tokenizer)

        # Right-pad every sequence to the length of the longest one. Done with
        # torch so the class needs no padding helper from another library.
        padded = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(ids, dtype=torch.long) for ids in encoded],
            batch_first=True,
            padding_value=self.PAD_ID,
        )
        mask = self.create_mask(padded.tolist())

        self.train_inputs = padded
        self.train_labels = torch.tensor(labels)
        self.train_mask = torch.tensor(mask)

    def __call__(self, batch_size, shuffle=True):
        """Create a DataLoader over the stored tensors.

        Args:
            batch_size: Number of examples per batch.
            shuffle: When True (the default) examples are drawn in random
                order; when False they are served in file order, which is
                what evaluation loaders usually want.

        Returns:
            A ``DataLoader`` yielding ``(input_ids, attention_mask, labels)`` batches.
        """
        dataset = TensorDataset(
            self.train_inputs,
            self.train_mask,
            self.train_labels,
        )
        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
        return DataLoader(
            dataset,
            sampler=sampler,
            batch_size=batch_size,
        )

    def load_corpus(self, path):
        """Read ``path`` into parallel lists of sentences (str) and labels (int).

        Blank lines are skipped.

        Raises:
            ValueError: If a non-blank line has no tab or a non-integer label.
        """
        sentences = []
        labels = []
        with open(path, 'r', encoding='UTF-8') as f:
            for line in tqdm(f):
                line = line.strip()
                if not line:
                    continue
                # The label is the last field; split on the last tab only so a
                # tab inside the sentence does not break the unpacking.
                sent, label = line.rsplit('\t', 1)
                sentences.append(sent)
                labels.append(int(label))
        return sentences, labels

    def vectorize(self, sentences, tokenizer):
        """Encode each sentence to a list of token ids, special tokens included."""
        return [
            tokenizer.encode(sent, add_special_tokens=True)
            for sent in sentences
        ]

    def create_mask(self, input_ids):
        """Return a nested list with 1 where an id is > 0 and 0 elsewhere."""
        return [[int(token_id > 0) for token_id in sent] for sent in input_ids]


In [25]:
# Rebuild the training loader through DataGen (this replaces the
# `train_dataloader` created step by step above).
data_gen = DataGen(train_path, tokenizer=bert_tokenizer)
train_dataloader = data_gen(batch_size=32)

# Show the first batch only: token ids, attention mask, labels.
for data in train_dataloader:
    input_data, input_mask, input_labels = data
    print(input_data)
    print(input_mask)
    print(input_labels)
    break

180000it [00:00, 1415703.24it/s]


tensor([[ 101, 3119, 6397,  ...,    0,    0,    0],
        [ 101, 1649, 4448,  ...,    0,    0,    0],
        [ 101, 2797, 3952,  ...,    0,    0,    0],
        ...,
        [ 101,  683, 2157,  ...,    0,    0,    0],
        [ 101, 6205, 1298,  ...,    0,    0,    0],
        [ 101, 4263, 2894,  ...,    0,    0,    0]])
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
tensor([2, 2, 8, 2, 9, 0, 2, 0, 0, 2, 1, 6, 6, 3, 8, 5, 6, 0, 8, 9, 1, 2, 7, 6,
        0, 4, 7, 8, 6, 5, 1, 8])


In [38]:
# Build the validation loader from the dev split.
# NOTE: `data_gen` is rebound here, so the training DataGen object is no longer reachable.
# NOTE: as called here the batches come back in random order (DataGen wraps the
# data in a RandomSampler), and the padding width is computed from the dev file
# alone, so it can differ from the training width.
data_gen = DataGen(dev_path, tokenizer=bert_tokenizer)
validation_dataloader = data_gen(batch_size=32)

# Show the first batch only: token ids, attention mask, labels.
for data in validation_dataloader:
    input_data, input_mask, input_labels = data
    print(input_data)
    print(input_mask)
    print(input_labels)
    break

10000it [00:00, 1240441.25it/s]


tensor([[  101,  1093,   772,  1501,  4764,  3309,  7481,   707,  6444,  3146,
           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  4640,  7716,   977,  6756,  1767,   677,  4028,  2661,  7789,
           671,  2391,  5125,  4868,  4567,  2891,  2797,  6841,  3324,  4640,
          7716,  1199,  7339,  7270,   102,     0,     0,     0,     0,     0],
        [  101,  7674,  6963,  1278,  4906,  1920,  2110,  8166,  2399,  7770,
          5440,  2497,  1357,  5310,  3362,  3389,  6418,  5143,  5320,  2458,
          6858,   102,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1921,  3823,  5428,  7032,  3959,  5401,  1863,  2270,  9111,
          2398,  6629,  5468,  2961,  1772,   817, 12075,  8129,   775,  8378,
          2835,   102,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1849,  1164,  4408,  1762,  735

## 创建模型

In [35]:
# Sequence classifier on top of a locally stored bert-base-chinese checkpoint.
model = BertForSequenceClassification.from_pretrained(
    '../../H/models/huggingface/bert-base-chinese',  # load from local files
    num_labels=10,  # one output per class
    output_attentions=False,
    output_hidden_states=False,
)
# Make every parameter trainable, i.e. fine-tune the whole network.
for param in model.parameters():
    param.requires_grad = True

# Use the GPU when one is available and fall back to the CPU otherwise, the
# same rule the training cell uses for its batches. `.to()` returns the model,
# so the cell still displays the architecture.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [37]:
from torch.optim import AdamW
from transformers import BertForSequenceClassification, BertConfig

# Optimizer
# The AdamW class shipped with transformers is deprecated in favour of
# torch.optim.AdamW. weight_decay=0.0 keeps the transformers default
# (PyTorch's own default would be 0.01).
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.0)


# Learning-rate schedule
from transformers import get_linear_schedule_with_warmup

epochs = 4

# One scheduler step is taken per training batch, so the schedule length is
# (batches per epoch) * (epochs). A training loop that runs a different number
# of epochs needs a schedule built for that number.
total_steps = len(train_dataloader) * epochs

# No warm-up; the learning rate decays linearly to 0 over total_steps.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)

In [39]:

# 预测精度
import numpy as np


def flat_accuracy(preds, labels):
    """Return the fraction of rows whose arg-max class equals the label.

    Args:
        preds: 2-D numpy array of per-class scores, one row per example.
        labels: numpy array of integer class ids; it is flattened before comparison.

    Returns:
        Number of matching predictions divided by the number of labels.
    """
    predicted = np.argmax(preds, axis=1).flatten()
    expected = labels.flatten()
    hits = np.sum(predicted == expected)
    return hits / len(expected)


# 格式化时间显示
import time
import datetime


def format_time(elapsed):
    """Format a duration given in seconds as an ``h:mm:ss`` string.

    The value is rounded to the nearest whole second before formatting.
    """
    whole_seconds = int(round(elapsed))
    return str(datetime.timedelta(seconds=whole_seconds))

In [47]:
import random


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

seed_val = 42
num_epochs = 2

# Seed every random number generator in use so runs are repeatable.
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# Keep the model on the same device the batches are moved to.
model.to(device)

# Build the learning-rate schedule for the number of epochs this loop actually
# runs. A schedule sized for a different epoch count would never finish its
# linear decay (or would reach zero too early).
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * num_epochs)

# Average training loss of each epoch.
loss_values = []

for epoch in range(0, num_epochs):

    # ========================================
    #               Training
    # ========================================

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, num_epochs))
    print('Training...')

    t0 = time.time()

    total_loss = 0

    model.train()

    for step, batch in enumerate(train_dataloader):

        if step % 1000 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(
                step, len(train_dataloader), elapsed))

        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        # Clear gradients left over from the previous step.
        model.zero_grad()

        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)
        # With labels supplied, the first element of the output is the loss.

        loss = outputs[0]

        total_loss += loss.item()

        loss.backward()

        # Clip the gradient norm to 1.0 to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        # The schedule advances once per batch.
        scheduler.step()

    avg_train_loss = total_loss / len(train_dataloader)

    loss_values.append(avg_train_loss)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epcoh took: {:}".format(format_time(time.time() - t0)))

    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    t0 = time.time()

    model.eval()

    # Count correct predictions and examples over the whole split. Averaging
    # per-batch accuracies instead would over-weight a smaller final batch.
    correct_predictions, nb_eval_examples = 0, 0

    for batch in validation_dataloader:

        batch = tuple(t.to(device) for t in batch)

        b_input_ids, b_input_mask, b_labels = batch

        # No gradients are needed for evaluation.
        with torch.no_grad():

            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask)

        # Without labels, the first element of the output is the logits.
        logits = outputs[0]

        predictions = torch.argmax(logits, dim=1)
        correct_predictions += (predictions == b_labels).sum().item()
        nb_eval_examples += b_labels.size(0)

    # max(..., 1) avoids a division by zero when the validation loader is empty.
    print("  Accuracy: {0:.2f}".format(correct_predictions /
                                       max(nb_eval_examples, 1)))
    print("  Validation took: {:}".format(format_time(time.time() - t0)))

print("")
print("Training complete!")


Training...
  Batch 1,000  of  5,625.    Elapsed: 0:01:28.
  Batch 2,000  of  5,625.    Elapsed: 0:02:57.
  Batch 3,000  of  5,625.    Elapsed: 0:04:26.
  Batch 4,000  of  5,625.    Elapsed: 0:05:55.
  Batch 5,000  of  5,625.    Elapsed: 0:07:24.

  Average training loss: 0.07
  Training epcoh took: 0:08:20

Running Validation...
(tensor([[-0.9750, -0.9032, -1.0428, -1.2282, -1.8091, -1.1258, -1.4433, 10.6276,
         -1.0413, -1.4416],
        [ 9.4536, -0.1385,  0.4910, -1.8662, -0.9697, -1.4404, -2.0414, -1.5558,
         -1.6219, -1.8161],
        [-1.5756, -0.9316, -1.4159, -0.6750, -1.1955, -1.2734, -2.1328, -0.1801,
         -1.4325,  9.7992],
        [ 8.9105, -1.1229,  0.2732, -1.3357,  0.3308, -1.4881, -1.9632, -1.4422,
         -2.0192, -1.9859],
        [ 9.1201, -1.1285,  1.0052, -1.8047,  0.0593, -1.5756, -2.3194, -1.4749,
         -1.7452, -1.9377],
        [ 1.8843,  8.9545, -1.2000, -0.8432, -2.4492, -1.0828, -0.0362, -1.2505,
         -1.9337, -1.5645],
        [-2.

(tensor([[ 5.1261,  1.1557,  5.0805, -3.6128,  1.2693, -2.6066, -1.4499, -2.7093,
         -2.9051, -2.8216],
        [ 0.2735,  0.6618,  8.9383, -2.7442, -0.2859, -1.6692, -1.7893, -2.2619,
         -1.8428, -1.8879],
        [-1.0824, -1.3471, -0.2617, -1.1427, -1.6825, -1.1703, -0.8613, 10.2826,
         -1.3380, -1.4640],
        [-1.3557, -1.0721, -1.0243, -0.8253, -1.0093, -0.3858,  9.8834, -1.2650,
         -1.5949, -0.9276],
        [-0.9394, -0.7933, -1.9942, 10.0387, -1.2387, -0.4629, -0.8100, -1.3999,
         -1.2540, -0.3405],
        [-0.8523, -1.2019, -1.0564, 10.2829, -1.7956, -1.0182, -0.6048, -1.1769,
         -0.8209, -1.0694],
        [-0.5252, -0.4917, -1.9085, -0.1552,  0.3460, -1.1178,  8.6665, -1.9582,
         -1.4863, -1.1723],
        [-2.1939, -1.1817, -0.9548, -1.4594,  0.9647, -1.1487, -1.3566, -1.2636,
          9.3669,  0.0387],
        [-0.4728,  0.5653,  8.9648, -2.3398,  0.4034, -2.1872, -1.8386, -2.4762,
         -1.2856, -1.9750],
        [-1.0551, 

(tensor([[-1.9168e-01,  1.0189e+01, -1.3950e+00, -1.3402e+00, -1.4206e+00,
         -1.1188e+00, -1.0341e+00, -8.7780e-01, -1.5277e+00, -9.9845e-01],
        [-8.0169e-01, -9.3854e-01, -8.4589e-01, -1.1665e+00, -1.6795e+00,
         -1.3462e+00, -1.3975e+00,  1.0502e+01, -9.8319e-01, -1.7063e+00],
        [-1.2984e+00, -1.1451e+00, -1.3700e+00, -9.3563e-01, -1.3476e+00,
         -1.3126e+00, -2.1390e+00,  1.0267e+00, -1.9423e+00,  9.5415e+00],
        [-1.0723e+00, -7.7108e-01, -7.6298e-01, -1.0244e+00, -5.8344e-01,
         -1.3740e+00,  9.8208e+00, -1.0172e+00, -1.7242e+00, -1.1806e+00],
        [-1.5470e+00, -1.3516e+00, -2.1107e+00,  9.2334e-01, -7.1960e-01,
          9.6140e+00, -2.9882e-01, -1.3498e+00, -6.4332e-01, -8.9871e-01],
        [ 9.0381e+00, -8.0108e-01,  1.2352e+00, -1.5964e+00, -1.5095e-02,
         -1.8108e+00, -2.0408e+00, -1.5929e+00, -2.3533e+00, -2.2924e+00],
        [-2.2689e+00, -1.3138e+00, -3.3731e-01, -1.0686e+00,  1.0133e+00,
         -1.4149e+00, -1.2866e+

(tensor([[-0.9673, -2.2465, -0.9163, -1.8792,  8.8773, -1.5512, -0.5083, -1.6086,
          0.2796, -1.9372],
        [-1.0243, -1.0900, -1.2932,  9.9998, -1.8007, -0.7412, -0.9021, -1.0699,
         -1.0525, -0.5707],
        [-0.2641, -0.0948, -1.5465, -0.7122, -0.8684,  9.7163, -0.7829, -1.7554,
         -1.1295, -1.2779],
        [-1.4781, -1.2270, -1.3913, -1.1262, -1.0024, -1.1212, -2.2514,  1.1016,
         -1.8822,  9.5198],
        [-1.1318, -0.9206, -1.9938, -0.2816, -1.0653,  9.9188, -0.2238, -1.1908,
         -1.0924, -0.8405],
        [ 9.1062, -0.1825, -0.5828, -1.4293, -1.2187, -0.5652, -1.8446, -0.9332,
         -1.7542, -1.1324],
        [-0.7365, 10.1736, -1.1950, -1.5133, -1.2287, -1.3586, -1.0011, -0.7618,
         -1.1627, -0.9560],
        [-0.9174, -1.2371, -1.1300, -1.1814, -1.5694, -1.2857, -1.2208, 10.6161,
         -1.1482, -1.4607],
        [-1.3241, -2.1081, -0.2101, -1.8940,  9.6309, -1.6575, -1.0558, -1.6454,
         -0.3801, -2.3379],
        [-1.6927, 

(tensor([[ 9.4299e+00, -3.5895e-01, -2.9501e-01, -1.6656e+00, -1.2004e+00,
         -8.0532e-01, -1.6459e+00, -1.2724e+00, -1.7782e+00, -1.5043e+00],
        [-1.1412e+00, -8.5371e-01, -1.1635e+00,  1.0268e+01, -1.8656e+00,
         -1.2197e+00, -6.4969e-01, -1.1888e+00, -4.1875e-01, -9.7583e-01],
        [ 3.3789e-01, -2.1720e+00, -7.2531e-02, -1.1901e+00,  8.8905e+00,
         -1.2189e+00, -1.4948e+00, -2.0710e+00, -5.1981e-01, -2.9851e+00],
        [-9.4788e-01, -7.2717e-01, -8.7591e-01, -1.8563e+00, -1.2456e+00,
         -9.5717e-01, -1.7729e+00, -7.3006e-02, -2.0151e+00,  9.5902e+00],
        [-5.1144e-02, -1.0202e+00,  9.4619e+00, -1.7178e+00, -8.0601e-01,
         -1.7365e+00, -1.0249e+00, -2.3128e+00, -1.1715e+00, -1.5587e+00],
        [-1.4198e+00, -8.8238e-01, -1.2145e+00,  1.0293e+01, -1.6933e+00,
         -1.2610e+00, -3.7840e-01, -1.3399e+00, -5.9520e-01, -7.6823e-01],
        [-7.4854e-01, -8.3803e-01, -1.7494e+00, -3.9402e-01, -1.1085e+00,
          9.8125e+00, -1.3392e-

(tensor([[-1.2108e+00, -7.7438e-01, -1.0949e+00,  1.0232e+01, -1.7432e+00,
         -1.1843e+00, -7.0246e-01, -1.4833e+00, -8.0318e-01, -9.3744e-01],
        [ 2.9784e-01,  9.6135e+00, -1.6797e+00, -5.5654e-01, -1.7670e+00,
         -1.4750e+00, -5.6864e-01, -6.8625e-01, -2.3827e+00, -8.3630e-01],
        [-8.9802e-01, -1.1221e+00, -1.1367e+00, -1.1849e+00, -1.6167e+00,
         -1.2080e+00, -1.4050e+00,  1.0634e+01, -1.1492e+00, -1.3320e+00],
        [-2.0194e+00, -1.1352e+00, -3.7597e-01, -9.7876e-01, -8.6691e-02,
         -7.4835e-01, -1.1947e+00, -1.1568e+00,  1.0165e+01, -5.5061e-01],
        [-2.4751e-01, -4.7908e-01,  9.4598e+00, -2.2253e+00, -7.6973e-01,
         -1.7966e+00, -1.2877e+00, -2.4227e+00, -1.1979e+00, -1.2772e+00],
        [-4.3693e-01,  8.8400e-02, -8.8902e-01, -1.5814e+00, -9.7164e-01,
         -8.6490e-01,  9.6723e+00, -1.2984e+00, -1.7599e+00, -1.6139e+00],
        [-1.9167e+00, -1.0325e+00, -2.2690e-01, -1.1356e+00,  4.0153e-02,
         -8.5924e-01, -1.2556e+

(tensor([[-2.3703e-01,  1.0176e+01, -1.5842e+00, -1.3968e+00, -1.3131e+00,
         -1.1301e+00, -1.1216e+00, -8.0131e-01, -1.1797e+00, -1.0801e+00],
        [ 1.7019e+00,  9.5842e-01,  7.3607e+00, -2.6605e+00, -1.4415e+00,
         -2.4497e+00,  1.4118e+00, -2.9205e+00, -2.4653e+00, -1.9722e+00],
        [ 2.3559e+00, -7.2258e-01,  4.1404e+00, -1.8817e+00, -1.3674e+00,
         -1.4333e+00,  6.7594e+00, -2.5806e+00, -2.6033e+00, -3.1198e+00],
        [-1.0889e+00, -1.4427e+00, -1.3706e+00, -1.0604e+00, -1.5627e+00,
         -1.0536e+00, -1.4286e+00,  1.0678e+01, -1.3306e+00, -1.0385e+00],
        [-1.0369e-01,  1.0178e+01, -1.6703e+00, -1.2280e+00, -1.4441e+00,
         -1.1315e+00, -9.0817e-01, -8.7896e-01, -1.6041e+00, -9.8678e-01],
        [-8.3373e-01, -1.8957e+00, -1.5688e+00, -1.4675e+00, -1.7560e+00,
          2.0861e+00, -1.6367e+00,  7.1620e+00, -2.2892e+00,  2.4323e+00],
        [ 8.0104e-02, -2.3076e-01,  6.8666e+00, -2.6784e+00,  4.7881e+00,
         -2.7252e+00, -1.8472e+

(tensor([[ 0.9461, -0.3526,  1.9509, -1.8520, -0.2802, -2.1301,  8.2894, -2.4452,
         -1.7428, -3.0328],
        [-1.4011, -1.3445, -1.8922, -0.4000,  0.9929, -1.5468,  9.2445, -1.0695,
         -1.4133, -1.0125],
        [-0.9497, -1.5515, -1.7226, 10.0515, -1.5658, -0.7656,  0.1180, -1.2202,
         -1.0096, -0.4277],
        [-2.3335, -1.1119,  1.0034, -0.7600,  0.6031, -1.6862, -1.8708, -1.4063,
          9.4101, -0.8654],
        [-0.9462, -1.0524, -0.5053, -1.2771, -1.6142, -1.4348, -1.4577, 10.5473,
         -1.0394, -1.4749],
        [-2.1525, -0.9106, -0.4337, -1.1011,  0.3613, -0.8293, -1.2486, -1.3344,
         10.1883, -0.7178],
        [ 8.6644,  1.5248,  0.6946, -2.3560, -0.5124, -2.0564, -2.1618, -1.6364,
         -2.2730, -2.0223],
        [ 9.2852, -0.5810,  1.5051, -2.0405, -0.6167, -2.1089, -1.9762, -1.6204,
         -2.2831, -1.9240],
        [-1.0983,  0.4093, -0.7524, -1.9982, -0.9415,  8.1051,  2.3336, -1.8026,
         -1.7459, -1.2041],
        [-2.0490, 

(tensor([[-2.1984, -1.0880, -0.4440, -1.0836,  0.0903, -0.8578, -1.2120, -1.1647,
         10.2151, -0.3616],
        [-1.0592, -2.2413, -0.6546, -1.4966,  9.8260, -1.3908, -0.9378, -1.5770,
         -0.7170, -2.3108],
        [-2.0944, -1.3164, -0.4264, -0.9382,  0.0999, -0.6620, -1.1521, -1.1937,
         10.1848, -0.6034],
        [-1.1083,  0.1478, -1.5643, -0.7426, -1.2293,  9.6866, -0.3506, -1.2050,
         -1.1247, -1.1869],
        [-0.4249, -0.9322, -0.7327, -1.9411, -1.0470, -1.1824, -2.0304,  0.2154,
         -2.2538,  9.2854],
        [ 0.2051,  1.1660,  8.3714, -2.4774,  1.3276, -2.2295, -2.1758, -2.5765,
         -1.5405, -2.7542],
        [-1.2325, -1.1771, -1.2745, -1.0640, -1.3438, -0.6899, -1.8587,  0.1023,
         -1.9832,  9.7614],
        [ 9.3392, -0.6867,  1.1835, -1.9359, -0.6345, -1.8388, -2.0793, -1.3258,
         -2.1348, -1.8954],
        [-2.0622, -0.5808, -1.4721, -1.0525, -1.3729, -0.2345, -2.3430, -0.5448,
         -1.2697,  9.7615],
        [-0.6549, 

(tensor([[-1.2008, -2.4619, -3.2399,  2.4441,  0.1220,  6.6640,  2.4511, -2.8669,
         -0.0463, -1.3868],
        [-0.1582, -1.0344,  9.3654, -2.1016, -0.1441, -2.1114, -0.7315, -2.6527,
         -1.2133, -1.6851],
        [-0.6791, -0.6508, -0.4624, -1.8716, -0.5301,  0.3777,  9.5639, -2.1125,
         -1.7409, -1.8328],
        [-0.1554, 10.1542, -1.6444, -1.2473, -1.3576, -1.1375, -1.0044, -0.8849,
         -1.5144, -1.0066],
        [ 0.0306,  9.7505, -1.0039, -1.7822, -0.9863, -1.5802, -1.4337, -0.6415,
         -1.9478, -0.5991],
        [-0.6497,  9.9756, -1.3473, -1.2829, -1.0019, -1.6203, -1.2336, -0.7179,
         -1.1235, -0.9107],
        [-1.0711,  0.7636, -0.0655, -2.5353,  1.2491, -1.5035,  7.5137, -1.0617,
         -2.1743, -0.8092],
        [ 0.8855,  0.1597, -1.5410, -1.4956, -0.4266, -1.3716,  7.7139, -0.1900,
         -1.5214, -1.2902],
        [-1.2569, -1.0067, -1.4323, 10.3584, -1.4487, -0.9672, -0.3506, -1.4078,
         -0.5822, -0.8709],
        [ 9.0045, 

(tensor([[-9.4574e-01, -2.5334e+00, -9.3309e-01, -1.2537e+00,  9.8192e+00,
         -1.4086e+00, -8.2363e-01, -1.5823e+00, -5.1625e-01, -2.1488e+00],
        [-6.2565e-01,  1.0211e+01, -1.3059e+00, -1.2695e+00, -1.2297e+00,
         -1.2261e+00, -9.7336e-01, -9.2151e-01, -1.2330e+00, -1.0625e+00],
        [ 1.3471e+00, -1.1515e+00,  8.7338e+00, -2.6221e+00, -9.4830e-01,
         -2.4239e+00,  1.2286e+00, -2.7157e+00, -2.1330e+00, -1.5515e+00],
        [-8.6702e-01, -8.9782e-01, -8.1275e-01, -1.3832e+00, -1.7116e+00,
         -1.2330e+00, -1.3668e+00,  1.0582e+01, -1.0248e+00, -1.4438e+00],
        [-1.4843e+00, -1.6632e+00, -3.0405e-01, -1.7119e+00, -1.2802e+00,
         -5.2584e-02, -1.6603e+00, -1.0731e+00, -1.4566e+00,  9.4055e+00],
        [ 6.9567e+00, -6.2895e-01, -1.3846e+00,  6.6081e-01,  1.6117e-01,
          1.4420e+00, -2.6891e+00, -1.9754e+00, -2.9692e+00, -1.0133e+00],
        [-1.4790e+00, -2.3707e+00, -6.3178e-01, -1.5165e+00,  5.1104e+00,
         -1.0970e+00, -1.8637e+

(tensor([[-1.0826, -0.8129, -1.0482, -1.0769, -1.7297, -1.3944, -1.4001, 10.5616,
         -1.0537, -1.2109],
        [-0.7265, -2.2957, -2.9994,  1.1574,  1.2022,  6.6405,  2.8376, -1.8667,
         -2.1613, -0.6400],
        [-1.3212, -0.6704, -1.4423, -1.0136, -1.2562, -0.9934, -2.0834,  0.2999,
         -1.9670,  9.7490],
        [-1.0130, -2.2079, -0.8956, -1.4011,  9.7880, -1.6139, -1.0072, -1.2765,
         -1.0165, -1.9833],
        [-1.9031, -0.7994, -1.8730, -0.8394, -0.3317, -1.2459, -1.9958, -1.9309,
          4.2875,  5.3153],
        [-1.2617, -0.7484, -1.4760, 10.1879, -1.4037, -0.9258, -0.8142, -1.1964,
         -1.0458, -0.5353],
        [-1.5186, -0.4127, -1.4722, -1.5717, -1.5930,  0.0857, -1.8361, -0.1662,
         -1.4331,  9.6452],
        [-0.5869, -0.8915,  0.4981, -1.7221, -0.7587, -1.4142,  9.5954, -1.6144,
         -2.0178, -1.3576],
        [-1.3612, -1.4372, -2.5720,  0.4711, -0.3723,  9.0279,  2.0036, -1.9458,
         -1.1756, -1.1699],
        [-1.0977, 

(tensor([[-7.1601e-01,  9.8920e+00, -1.3030e+00, -1.3844e+00, -1.2793e+00,
         -1.1028e+00, -6.4816e-01, -9.6178e-01, -1.4774e+00, -9.8022e-01],
        [-1.2857e+00, -1.1135e+00, -1.2830e+00,  1.0313e+01, -1.5987e+00,
         -8.8839e-01, -6.4324e-01, -1.2920e+00, -7.6389e-01, -7.7768e-01],
        [-8.5005e-01, -1.8933e+00, -1.2383e+00, -9.6943e-01,  2.1189e+00,
          8.4828e+00, -9.7758e-01, -1.7718e+00, -8.9961e-01, -1.4966e+00],
        [-8.4735e-01,  5.7364e-01,  8.6540e+00, -1.9284e+00,  5.0813e-01,
         -1.9538e+00, -1.7511e+00, -2.6463e+00, -9.4373e-01, -2.0765e+00],
        [ 9.2437e+00, -9.9225e-01,  6.1741e-01, -1.6145e+00, -2.2618e-01,
         -1.6040e+00, -1.8913e+00, -1.3704e+00, -2.1043e+00, -1.9082e+00],
        [-9.5001e-01, -9.8542e-01, -1.1526e+00, -8.9925e-01, -4.5673e-01,
         -1.3536e+00,  9.8619e+00, -4.3850e-01, -2.0611e+00, -8.2767e-01],
        [-3.4296e-01,  1.0128e+01, -1.7623e+00, -9.9861e-01, -1.8651e+00,
         -7.1358e-01, -7.7017e-

(tensor([[-1.9917, -1.0113, -1.6801, -1.4498, -1.7948, -1.4937, -1.9946,  2.7353,
         -1.1296,  8.6821],
        [ 9.1046, -1.0030,  0.8208, -1.6837, -0.1960, -1.5600, -1.8565, -1.5001,
         -2.0813, -2.0734],
        [-1.1026, -0.5605, -1.7137, -0.5620, -1.2157,  9.9491, -0.1257, -1.4039,
         -1.1733, -0.7902],
        [-1.2316, -0.8330, -1.1181, -0.4455, -0.8806, -0.5665,  9.8467, -1.4435,
         -1.6389, -1.3060],
        [-2.1300, -1.4115, -0.9358, -1.1266,  2.2877, -0.8449, -2.2589, -1.5939,
          9.4199, -0.4015],
        [-0.6639, -1.1520, -1.1769, -0.8649, -0.5124,  9.4449, -0.6874, -1.3663,
         -0.9166, -1.1442],
        [-2.1361, -1.1012, -0.4371, -0.6552,  0.1365, -1.0823, -1.0551, -1.2494,
         10.0694, -0.5894],
        [ 9.1713, -1.1232,  0.1440, -1.2642, -0.5978, -1.3497, -1.9428, -1.1633,
         -1.6662, -1.5734],
        [-2.1342, -1.0165, -0.1930, -1.0580, -0.3767, -0.7080, -0.9601, -1.0229,
         10.1111, -0.4378],
        [-2.0320, 

(tensor([[-2.0044, -1.4111, -1.5986, -1.8162,  0.1808, -1.7280, -2.0636,  0.8185,
          0.1027,  7.4894],
        [ 9.3641, -0.4747,  0.2078, -1.5538, -1.1925, -1.2089, -1.5713, -1.5136,
         -1.5583, -1.6916],
        [-0.7230, -1.0034, -0.8731, -1.3491, -1.5178, -1.3229, -1.3033, 10.4970,
         -1.4111, -1.5198],
        [-1.5574, -0.9738, -1.6959, -0.8652, -1.7545, -0.9533, -1.9540,  1.4627,
         -1.7813,  9.3237],
        [-1.4873, -1.3448, -1.7458, -0.7520, -1.0025, -1.4678, -1.4975,  0.5558,
         -1.9894,  9.6353],
        [-0.6397, -0.8445, -1.0295, -1.0550, -0.3749, -1.4061,  9.7977, -1.1789,
         -1.5600, -1.5253],
        [-0.8467, -1.0873, -0.6391, -1.5062, -1.3229, -1.5045, -1.2260, 10.5019,
         -1.3897, -1.5256],
        [-0.6831, 10.1232, -0.7096, -1.5595, -1.1668, -1.2249, -1.3099, -0.7550,
         -1.4907, -0.9447],
        [-1.1370, -1.2433, -1.2398, -0.8851, -1.5122, -1.0184, -1.5034, 10.6154,
         -1.0546, -1.3150],
        [ 0.5017, 

(tensor([[-1.1826e+00, -1.4108e+00, -1.0418e+00, -1.0996e+00, -1.5508e+00,
         -1.1252e+00, -9.7478e-01,  1.0589e+01, -1.3806e+00, -1.2913e+00],
        [-8.4846e-01, -1.2589e+00, -1.0753e+00,  1.0223e+01, -1.8822e+00,
         -9.7321e-01, -3.0888e-01, -1.1126e+00, -1.0217e+00, -7.4379e-01],
        [-1.5127e+00, -7.4046e-01, -9.2543e-01, -1.7321e+00, -9.0417e-01,
         -3.9625e-01, -1.8352e+00, -7.2565e-01, -1.6292e+00,  9.6730e+00],
        [ 2.1588e-01, -1.1655e+00,  9.9941e-01, -2.5571e+00,  8.5294e+00,
         -1.5663e+00, -1.3684e+00, -2.3641e+00, -4.3238e-01, -3.5011e+00],
        [-1.6795e-01, -1.4536e+00,  2.3801e+00, -1.9269e+00,  8.3467e+00,
         -1.2620e+00, -1.6698e+00, -2.9195e+00, -6.1707e-01, -3.6289e+00],
        [-2.2244e+00, -9.8560e-01,  4.1646e-01, -1.9511e+00,  1.5030e+00,
         -1.5999e+00, -1.5885e+00, -1.8629e+00,  9.1123e+00, -1.3473e-01],
        [ 3.2767e-01, -7.3182e-01,  9.4420e+00, -2.0616e+00, -4.0375e-01,
         -1.8129e+00, -1.5210e+

(tensor([[-1.1876e+00, -9.2109e-01, -1.6379e+00,  1.0317e+01, -1.4296e+00,
         -9.8724e-01, -2.0637e-01, -1.1784e+00, -9.1899e-01, -8.0065e-01],
        [ 9.0535e+00, -7.2555e-01,  2.6271e+00, -1.5749e+00, -1.3313e+00,
         -2.4075e+00, -1.9497e+00, -1.7772e+00, -1.3853e+00, -2.4287e+00],
        [-1.6210e+00, -1.0077e+00, -1.0285e+00, -1.3132e+00, -9.5948e-01,
         -1.3031e+00, -2.2717e+00,  2.3874e-01, -1.6175e+00,  9.6965e+00],
        [-1.2489e+00, -1.5703e+00, -6.6907e-01, -1.3181e+00, -1.1427e+00,
         -1.5330e+00, -2.1914e+00, -2.6225e-01, -9.3134e-01,  9.4962e+00],
        [-5.6885e-01,  1.0182e+01, -1.5290e+00, -1.3634e+00, -1.1138e+00,
         -1.2190e+00, -1.0069e+00, -7.5810e-01, -1.3143e+00, -1.0088e+00],
        [-2.3541e+00, -1.3473e+00, -2.0463e-01, -2.4198e+00,  8.1699e+00,
         -2.3866e+00, -2.2694e+00, -1.7679e+00,  3.2982e+00, -6.5391e-01],
        [-1.3134e+00, -4.3745e-01, -1.4940e+00,  1.0108e+01, -1.8475e+00,
         -9.9144e-01, -1.2176e+

(tensor([[-2.8183, -2.0992, -1.1949, -1.1384,  7.0278, -1.7720, -0.8845, -1.4545,
          5.5953, -1.7430],
        [-1.5451, -0.6121, -2.1873, 10.2281, -1.3652, -0.5106, -0.8842, -1.3142,
         -0.7371, -0.5443],
        [-1.9764, -1.1890, -0.7545, -1.0631,  0.6660, -0.8830, -1.3192, -1.5084,
         10.1843, -0.5389],
        [ 8.9886,  0.3626, -0.5447, -1.8284, -1.1124, -1.1599, -1.7936, -0.8901,
         -1.8983, -1.2574],
        [-1.1353, -1.1776, -1.8172, 10.3075, -1.7003, -0.5515,  0.1015, -1.2287,
         -0.9875, -0.7405],
        [-0.3864, 10.1461, -1.7534, -1.3408, -1.1685, -0.9578, -1.1413, -0.7720,
         -1.3486, -0.9763],
        [ 9.2437, -0.8317,  0.4472, -1.4895, -0.5463, -1.4792, -2.0365, -1.3656,
         -1.7372, -1.5796],
        [-2.0578, -1.4652, -1.5345, -0.7729, -0.1122, -0.7502, -1.9813, -0.8683,
         -1.3862,  9.4992],
        [ 3.0506,  8.7292, -0.1187, -2.3349, -1.3513, -0.8621, -1.9946, -1.2194,
         -2.5655, -1.6149],
        [ 1.7366, 

(tensor([[-1.5108, -0.9637, -1.3246, 10.3077, -2.0472, -0.5175, -0.7432, -0.9971,
         -0.6337, -0.6590],
        [-0.9583, -1.2190, -1.0757, -1.1205, -1.7719, -1.2250, -1.3922, 10.6180,
         -0.8631, -1.1713],
        [-1.2672,  9.7832, -0.8922, -1.2442, -1.3360, -1.0847, -0.9812, -0.7543,
         -0.6632, -1.0687],
        [-0.8804, -0.3701, -1.6596, -0.7669, -1.2776,  9.8946, -0.2814, -1.3712,
         -1.4537, -0.7970],
        [-0.6520, -1.6682,  7.3940, -3.0032,  1.7619, -1.4529,  0.1645, -0.5260,
         -1.4714, -2.1298],
        [-1.4227, -0.4363, -1.0488,  9.5147, -1.9431, -1.1462, -0.2494, -1.4730,
         -0.6141, -1.0749],
        [-0.8369, -0.4763, -1.8834, -0.5986, -0.7218,  9.7976, -0.4043, -1.5117,
         -1.2249, -1.1470],
        [-0.5471, 10.1623, -1.7865, -1.1759, -1.2505, -1.1305, -0.8622, -0.9173,
         -1.4109, -0.8547],
        [-1.3162, -0.8371, -2.0628, -0.7645, -1.7244, -1.1647, -2.3687,  2.9291,
         -1.8335,  8.6113],
        [-2.1956, 

(tensor([[-2.4543, -1.9042,  0.3854, -1.3977,  2.7969, -1.6079, -1.7573, -1.9624,
          9.3301, -0.9839],
        [-1.6162, -0.6187, -1.6653, -0.4759, -0.4048, -0.8563,  9.8510, -1.0629,
         -1.6776, -0.9693],
        [-2.2281, -1.2619, -0.4285, -0.8740,  0.1122, -0.7748, -1.0724, -1.2340,
         10.2224, -0.6274],
        [-1.1135, -2.8122, -1.3334, -1.9764,  4.6609,  3.1018, -1.6914, -2.1343,
         -0.9488,  3.5811],
        [-1.4823, -1.2421, -1.4033, 10.2360, -1.2785, -0.6635, -0.7911, -1.1216,
         -0.6488, -0.7043],
        [-0.8882, -1.1364, -1.5912, -1.1349, -1.5048, -1.0229, -1.4328, 10.6325,
         -1.2240, -1.1154],
        [-0.8374, -0.9595, -0.3169, -1.0883, -0.5767, -1.4476,  9.9459, -1.1884,
         -1.7587, -1.3833],
        [-1.6094, -2.7740, -2.3670,  0.3979,  7.5678,  0.5929,  0.8932, -1.7649,
         -1.0603, -1.5321],
        [-0.9991,  0.0547, -1.8846, -0.7821, -0.9356,  9.6903, -0.4500, -1.5964,
         -1.3414, -0.9379],
        [ 2.4958, 

(tensor([[-2.4051, -1.0452, -0.7134, -1.0992,  0.6085, -0.9659, -1.1379, -1.3448,
         10.0563, -0.1873],
        [ 0.8839, -0.1081,  9.3990, -2.1867, -1.0730, -1.9851, -1.6784, -2.2962,
         -1.9386, -1.4108],
        [-2.1506, -2.2691, -1.4630, -0.9070,  2.8178, -0.2623, -1.0585, -1.4386,
          8.8257, -1.2420],
        [-0.9492, -0.7632, -1.0366, 10.1104, -1.9790, -1.4865, -0.6264, -1.1598,
         -0.8213, -0.7447],
        [-1.1105, -0.6563, -1.9316, -0.5538, -1.0207,  9.8704,  0.0358, -1.4473,
         -1.1263, -0.8817],
        [-1.2748, -1.0965, -1.7885, -0.7156, -1.1837,  9.6908,  0.4228, -0.9394,
         -0.9533, -0.8008],
        [-1.2068, -1.1263, -1.1076, 10.3184, -1.9478, -0.9727, -0.7009, -1.3026,
         -0.5993, -0.8284],
        [-1.1571, -2.3776, -0.3573, -1.4723,  9.8818, -1.5809, -0.8666, -1.7754,
         -0.4338, -2.1875],
        [-0.7887, -2.0414, -1.2066, -0.7681,  1.9307,  0.9017,  7.9346, -1.6289,
         -1.5792, -1.9927],
        [ 9.2110, 

(tensor([[-2.2357e+00, -1.0841e+00, -1.7295e-01, -8.8017e-01, -1.6653e-01,
         -7.6715e-01, -1.1429e+00, -1.2457e+00,  1.0116e+01, -4.4833e-01],
        [-1.7909e+00, -1.5029e+00, -1.0002e+00, -1.5221e+00, -9.0686e-01,
         -6.1517e-01, -2.3026e+00,  2.2821e-01, -1.1455e+00,  9.5942e+00],
        [-1.0985e+00, -5.7561e-01, -1.5215e+00, -9.6581e-01, -1.9522e+00,
         -1.3028e+00, -1.5429e+00,  1.0554e+01, -8.4713e-01, -1.1636e+00],
        [-1.2410e+00, -1.0336e+00, -2.4308e+00, -7.0073e-02, -1.2062e+00,
          9.7808e+00,  3.7474e-01, -1.6097e+00, -1.2775e+00,  2.0579e-03],
        [ 8.3313e+00, -4.3187e-01,  9.6868e-01, -1.0124e+00, -1.4881e+00,
         -8.2926e-01, -9.4401e-01, -2.1595e+00, -1.2920e+00, -2.5268e+00],
        [ 5.9980e-01,  9.0251e+00,  1.8135e+00, -1.9506e+00, -1.2896e+00,
         -1.7555e+00, -1.1729e+00, -1.5549e+00, -1.8811e+00, -1.6535e+00],
        [ 1.8912e+00,  9.2911e+00,  2.1467e-01, -1.6547e+00, -2.0466e+00,
         -9.2959e-01, -8.3301e-

(tensor([[-1.4887, -1.0470, -1.6597, -0.9320, -1.6158, -0.2827, -1.8053,  0.6952,
         -1.9760,  9.5900],
        [-1.1368, -1.3520, -1.3034, 10.1026, -1.7426, -0.9050, -0.3417, -0.8518,
         -0.8131, -0.8238],
        [-1.0853, -1.1093, -0.9523, 10.2675, -1.7393, -1.3800, -0.5297, -1.2634,
         -0.4984, -0.9466],
        [ 1.2586,  9.8153, -0.9754, -1.7250, -1.5028, -1.1006, -0.5795, -1.4169,
         -2.2658, -1.4448],
        [-2.2618, -1.3460, -0.6174, -0.9460,  0.3348, -0.8258, -1.1438, -1.1891,
         10.2557, -0.5138],
        [ 1.9094, -0.9587,  8.8184, -2.0399, -1.6045, -1.4806,  0.0947, -2.7724,
         -2.6279, -1.8100],
        [-2.3868, -1.5421, -0.6863, -0.8919,  0.3914, -0.8988, -0.9540, -1.1107,
         10.1857, -0.3604],
        [-1.8836, -1.6288, -0.8974, -1.5099, -0.3080, -1.0574, -2.4056, -0.8707,
         -0.3731,  9.5559],
        [-1.1692, -1.0163, -0.9903, -1.0359, -0.1617, -1.1624,  9.9819, -1.2847,
         -1.5235, -1.2877],
        [-0.9933, 

(tensor([[-1.5366, -0.3606, -2.0997, -0.8806, -1.1954,  1.7122,  8.9582, -1.8829,
         -1.5799, -0.5022],
        [ 2.6927,  0.6676, -0.6266, -1.4124, -0.1892, -1.9434,  6.7878, -1.9752,
         -1.5851, -3.1089],
        [-0.9010, -1.0262, -0.9678, -1.0568, -1.7792, -1.2274, -1.2788, 10.5243,
         -1.2524, -1.2896],
        [-0.9262, -0.0257, -1.0688, -1.3719,  0.6619, -1.3988,  9.3539, -1.0433,
         -2.0449, -1.7969],
        [-0.9599, -1.0557, -2.1318, -0.1264, -0.9219,  9.8693, -0.1527, -1.2405,
         -1.1797, -0.9175],
        [-0.9547, -1.0553, -1.1905, -1.1082, -1.6907, -1.2947, -1.4697, 10.6753,
         -0.9705, -1.4102],
        [-0.9474, -0.6800, -1.3403, -1.2256, -0.3629, -0.8927,  9.9319, -1.2372,
         -1.5387, -1.4394],
        [ 9.3120,  0.5143,  0.2652, -1.5436, -1.5105, -1.5813, -1.9584, -1.4997,
         -1.7610, -1.6486],
        [-0.4621, 10.2239, -1.4575, -1.3024, -1.3724, -1.1676, -0.8084, -0.8348,
         -1.3644, -1.0819],
        [-1.3014, 

(tensor([[-1.2133, -1.1687, -1.3819, -1.5316, -1.2413, -1.3796, -1.4405, 10.6872,
         -1.2648, -0.8497],
        [-1.0182, -1.3530, -1.0530, -1.6567, -1.1294, -1.5392, -1.3212, 10.6082,
         -1.3204, -1.0647],
        [-1.0863, -0.6924, -1.0366, -0.8303, -0.7557, -1.2353,  9.9767, -0.7313,
         -1.9166, -1.2854],
        [-1.5194, -1.2538, -1.2661, -1.2659, -1.2628, -0.8031, -2.1596,  0.2564,
         -1.5986,  9.8377],
        [-0.9480, -1.1582, -1.0488, -1.1636, -1.4873, -1.2854, -1.4326, 10.6334,
         -1.1802, -1.1541],
        [-1.4058, -1.8171, -2.1316,  9.2514, -1.1004,  1.5767, -0.0616, -1.2734,
         -0.2305, -1.0904],
        [ 2.3060, -0.4821,  7.8641, -2.1017, -1.0826, -2.4668,  1.4426, -2.6971,
         -2.6687, -2.5083],
        [-1.5753, -0.9165, -1.4364, -0.5099, -1.7452, -0.9943, -1.5658, -0.0489,
         -1.9418,  9.7804],
        [ 3.9726, -1.4858, -2.2828,  0.5741, -0.7754,  5.5189, -1.7883, -2.6820,
         -1.7710,  0.3756],
        [-0.4504, 

(tensor([[-2.0530, -1.0906, -0.2338, -0.9932, -0.2692, -0.8033, -1.0036, -1.0195,
         10.1377, -0.5980],
        [ 1.0522, -0.0923,  2.8879, -1.7630, -0.6694, -2.1018,  7.9365, -2.5426,
         -2.3234, -3.0125],
        [-2.1189, -1.9316, -1.3541, -0.8916,  3.4518, -1.1173, -1.4798, -1.4163,
          9.2302, -1.2400],
        [ 9.2248, -0.0594, -0.7271, -1.3629, -1.1756, -0.6956, -1.8456, -1.1352,
         -1.8403, -1.3263],
        [-0.1841, -0.3418,  9.4542, -2.1770, -0.8781, -1.7767, -1.2639, -2.4188,
         -1.2277, -1.2477],
        [-0.4868, 10.1872, -1.5426, -1.3684, -1.2342, -1.0737, -1.0608, -0.9408,
         -1.3270, -0.9021],
        [-2.5442, -1.3133, -0.8468, -0.8259,  0.5555, -0.7271, -1.1143, -1.3189,
         10.1768, -0.2330],
        [ 5.1161, -0.8657,  6.9781, -1.2626, -0.3548, -2.9598, -1.7404, -1.7887,
         -2.8497, -2.9998],
        [ 9.1212, -0.1082,  2.1744, -1.7153, -1.0015, -2.3319, -1.6579, -2.1352,
         -2.4565, -2.3727],
        [-0.0301, 

(tensor([[-1.9848e-01, -8.8444e-01,  9.5250e+00, -1.7673e+00, -7.6818e-01,
         -1.7373e+00, -1.3248e+00, -2.4229e+00, -1.0837e+00, -1.3683e+00],
        [-9.9520e-01, -2.7108e+00,  2.8354e-01, -2.0459e+00,  8.6197e+00,
         -2.1391e+00,  5.6145e-02, -2.0378e+00,  1.9427e+00, -3.2558e+00],
        [ 9.4368e+00, -4.1783e-01, -2.9063e-02, -1.5555e+00, -1.2862e+00,
         -9.6910e-01, -1.8760e+00, -1.2734e+00, -1.6331e+00, -1.5416e+00],
        [-1.8886e+00, -1.4844e+00, -7.7471e-01, -1.3867e+00, -9.1491e-01,
         -1.2932e+00, -2.2842e+00,  6.9441e-01, -1.3860e+00,  9.6956e+00],
        [-1.4717e+00, -7.7986e-01, -3.8730e+00,  3.7408e+00, -1.0440e+00,
         -4.8132e-01,  3.2319e+00,  1.2859e-01, -3.8292e+00,  3.7315e+00],
        [-5.8801e-01, -1.1724e+00,  9.4240e+00, -1.7526e+00, -3.2238e-01,
         -1.8192e+00, -1.3592e+00, -2.2035e+00, -9.7547e-01, -1.3006e+00],
        [-8.9564e-01, -7.5977e-01, -1.3340e+00, -1.0174e+00, -1.7091e+00,
         -1.4854e+00, -1.2843e+

(tensor([[ 8.8856e-01, -4.7815e-01,  8.8298e+00, -2.5635e+00,  8.3936e-01,
         -2.3419e+00, -1.0781e+00, -2.9036e+00, -1.7335e+00, -2.6016e+00],
        [-6.9507e-01, -6.5954e-01, -1.1859e+00, -1.1520e+00, -1.5223e-02,
         -1.3180e+00,  9.7201e+00, -9.2248e-01, -1.8037e+00, -1.5599e+00],
        [-1.3701e+00, -1.1691e+00, -6.6208e-01, -1.9869e+00, -6.8168e-01,
         -1.5571e+00, -2.3734e+00,  2.7320e-01, -1.3595e+00,  9.5071e+00],
        [-4.2310e-01,  1.0223e+01, -1.3838e+00, -1.2912e+00, -1.2509e+00,
         -1.2353e+00, -1.0189e+00, -9.2518e-01, -1.3063e+00, -1.0309e+00],
        [ 1.3133e+00,  4.5884e-01,  6.7244e+00, -3.1802e+00,  3.7003e+00,
         -2.7951e+00, -3.3782e+00, -2.3951e+00, -8.6494e-01, -2.5743e+00],
        [-3.8788e-01,  1.0208e+01, -1.4235e+00, -1.4875e+00, -1.2372e+00,
         -1.1696e+00, -9.8378e-01, -7.8471e-01, -1.3767e+00, -1.1454e+00],
        [-3.2644e-03,  1.0070e+01, -1.5928e+00, -1.3664e+00, -1.1129e+00,
         -1.1247e+00, -1.1581e+

(tensor([[-1.1446e+00, -2.6165e+00, -8.6690e-01, -1.3411e+00,  9.8047e+00,
         -1.1214e+00, -5.3345e-01, -1.6961e+00, -4.8061e-01, -2.2585e+00],
        [ 9.1980e+00, -7.4570e-01,  1.6850e+00, -2.0596e+00, -4.8019e-01,
         -2.0360e+00, -1.8845e+00, -1.6615e+00, -2.1081e+00, -2.2543e+00],
        [-8.2989e-01,  1.0152e+01, -1.0848e+00, -1.1490e+00, -1.3946e+00,
         -1.3191e+00, -1.1540e+00, -8.0729e-01, -1.0290e+00, -1.0144e+00],
        [ 2.2634e+00, -4.5610e-01,  9.0365e+00, -1.6857e+00, -1.7782e+00,
         -1.8729e+00, -1.8103e+00, -2.5843e+00, -1.7227e+00, -1.7835e+00],
        [-1.3192e-01, -1.4392e+00, -1.0449e+00,  9.4257e+00, -1.3267e+00,
         -1.2651e+00,  6.8910e-01, -2.1608e+00, -1.0583e+00, -1.2436e+00],
        [ 9.5189e+00, -9.9669e-01,  9.6931e-01, -1.7073e+00, -8.3583e-01,
         -1.3322e+00, -1.9288e+00, -1.3732e+00, -2.3194e+00, -1.6953e+00],
        [-2.9343e-01, -8.4235e-01,  9.5130e+00, -2.0791e+00, -2.3930e-01,
         -1.7611e+00, -1.3758e+

(tensor([[-2.4410, -1.3562, -0.8323, -1.1275,  0.5399, -0.6922, -1.2985, -1.4954,
         10.0017,  0.2021],
        [ 9.1390, -1.0680,  0.4676, -1.4414, -0.2709, -1.3791, -1.9948, -1.1976,
         -2.1163, -1.7686],
        [ 9.2661, -0.5729,  0.2679, -1.6596, -0.7082, -1.3597, -1.4688, -1.5182,
         -1.9562, -1.9342],
        [-0.3031, 10.2006, -1.3942, -1.3993, -1.2094, -1.3968, -1.0893, -0.8082,
         -1.3966, -0.9223],
        [-0.9464, -0.9680, -1.4746, -0.5761, -1.4111,  9.7151, -0.2585, -0.8389,
         -1.3738, -0.5134],
        [-1.2833, -0.9971, -1.5728, 10.1722, -1.6836, -0.9456, -0.1358, -1.2584,
         -0.7464, -0.6176],
        [-0.7493, -2.2571, -1.2579, -1.2470,  9.7149, -1.2463, -0.7438, -1.2705,
         -1.1113, -2.1971],
        [-1.0422, -0.8573, -1.7210, -0.8187, -0.9808, -0.0910, -2.1240, -0.4666,
         -1.9402,  9.4402],
        [ 9.3549, -0.4262, -0.1840, -1.4726, -1.4005, -0.8805, -1.8479, -1.1372,
         -1.6861, -1.1857],
        [-0.3728, 

(tensor([[ 9.3740e+00, -7.7104e-01,  1.2466e+00, -1.7137e+00, -7.4745e-01,
         -1.7987e+00, -2.1136e+00, -1.4084e+00, -2.0702e+00, -1.9364e+00],
        [-8.8305e-01, -8.0288e-01, -3.7558e-01, -1.2448e+00, -3.0619e-01,
         -1.3585e+00,  9.8073e+00, -1.4646e+00, -1.5457e+00, -1.4967e+00],
        [-1.1898e+00, -2.0224e+00, -6.9240e-01, -1.5706e+00,  9.8266e+00,
         -1.5590e+00, -8.9162e-01, -1.4730e+00, -8.6418e-01, -2.2619e+00],
        [-1.5815e+00, -5.6795e-01, -2.6975e+00,  9.4639e+00, -2.9669e-01,
          1.2200e+00, -1.4778e+00, -1.7090e+00, -1.2243e+00, -2.9760e-01],
        [-1.2773e+00, -1.3060e+00, -1.4831e+00, -1.0988e+00, -1.7404e+00,
         -9.6977e-01, -1.2560e+00,  1.0712e+01, -8.2837e-01, -9.2897e-01],
        [-1.5853e+00, -1.6499e+00, -7.6104e-01, -1.4149e+00, -3.4591e-01,
         -1.3530e+00, -2.1852e+00, -6.2976e-01, -1.0472e+00,  9.6296e+00],
        [-1.6107e+00, -1.9382e+00, -1.5471e+00, -9.2300e-01, -7.1730e-01,
         -1.3062e+00, -2.3740e+

(tensor([[-1.2877,  0.9850, -2.2983, -0.4341, -2.1241,  0.9657, -2.2064, -0.5872,
         -2.3929,  8.9110],
        [-0.9259, -2.2677, -0.6921, -1.5063,  9.8306, -1.6145, -1.0433, -1.5373,
         -0.5874, -2.1259],
        [-2.3894, -1.0640, -0.9249, -1.1590,  1.5780, -1.0119, -1.6321, -1.6433,
         10.0131, -0.5019],
        [-1.3753, -0.5880, -2.0205, -0.5126, -1.1512,  9.5525, -0.4189, -1.7611,
         -1.2298,  0.4568],
        [-0.6190, 10.2295, -1.1168, -1.2946, -1.2561, -1.2996, -0.9911, -1.0409,
         -1.2644, -1.1100],
        [ 9.2542, -1.1158,  0.3712, -1.4268, -0.5930, -1.1120, -2.1657, -1.4266,
         -1.6304, -1.5630],
        [-1.1938, -1.3332, -1.6960,  9.7280, -1.3540, -1.1367,  1.7820, -1.3345,
         -0.9060, -1.2194],
        [-0.3053, 10.2091, -1.4036, -1.4182, -1.2883, -1.2144, -1.0696, -0.8538,
         -1.3681, -1.0480],
        [ 9.0188,  0.9833, -0.1584, -1.9663, -1.7388, -1.1307, -1.7734, -1.2058,
         -1.7255, -1.4007],
        [-0.4729, 

(tensor([[-1.3885e+00, -1.1254e+00, -6.8269e-01, -1.8450e+00, -1.4611e+00,
          4.1858e-01, -2.0438e+00,  2.2253e-01, -1.9452e+00,  9.4072e+00],
        [-6.7763e-01,  1.0059e+01, -1.4552e+00, -1.6458e+00, -4.3046e-01,
         -1.2581e+00, -1.2578e+00, -1.0002e+00, -1.2627e+00, -1.0266e+00],
        [-2.0398e+00, -1.3270e+00, -7.1280e-01, -1.0299e+00,  7.4629e-01,
         -8.3466e-01, -1.3146e+00, -1.3557e+00,  1.0187e+01, -8.1340e-01],
        [ 1.0352e+00,  9.2698e+00,  4.2406e-01, -2.3158e+00, -7.7599e-01,
         -1.4148e+00, -3.4226e-01, -1.7083e+00, -2.5988e+00, -1.7222e+00],
        [-7.9485e-01, -8.3886e-01, -1.7672e+00,  9.9260e+00, -1.3574e+00,
         -8.9730e-01,  2.4721e-01, -1.9838e+00, -1.1878e+00, -8.7032e-01],
        [-1.1205e+00, -9.6692e-01, -1.5035e+00,  1.0421e+01, -1.7437e+00,
         -7.1208e-01, -5.1616e-01, -1.1540e+00, -6.7030e-01, -9.2969e-01],
        [-3.3651e-01, -1.9774e+00, -7.0855e-01, -1.3656e+00,  9.3097e+00,
         -1.8376e+00, -1.0489e+

(tensor([[-8.7449e-01, -9.8956e-01, -7.6721e-01, -8.9724e-01, -2.9600e-01,
         -1.1650e+00,  9.9442e+00, -1.4072e+00, -1.6560e+00, -1.6307e+00],
        [-1.1755e+00, -4.6907e-01, -1.3642e+00,  2.8820e-01, -1.5228e+00,
         -1.6639e-01,  9.3723e+00, -1.7132e+00, -1.5064e+00, -9.3787e-01],
        [-3.2944e-01,  1.0158e+01, -1.6320e+00, -1.3771e+00, -1.1520e+00,
         -1.1284e+00, -9.1730e-01, -1.0003e+00, -1.5010e+00, -9.8297e-01],
        [ 1.1045e+00,  4.6750e-01,  3.1223e+00, -2.7936e+00,  7.0465e+00,
         -2.2081e+00, -2.6158e+00, -2.8465e+00, -4.2250e-01, -3.9241e+00],
        [-1.4032e+00, -1.6129e+00, -5.7232e-01, -1.2283e+00,  1.3158e+00,
         -1.0595e+00,  9.3436e+00, -1.4272e+00, -1.4214e+00, -1.9109e+00],
        [-1.3884e+00, -2.1797e+00, -2.4861e+00,  3.8374e+00, -8.2903e-01,
          8.2390e+00,  7.5952e-02, -1.1655e+00, -1.2050e+00, -9.4848e-01],
        [-9.3948e-01, -9.7598e-01, -1.3018e+00, -1.1064e+00, -1.5806e+00,
         -1.3780e+00, -1.3428e+

(tensor([[-2.2483, -1.5186, -0.3421, -0.9588,  0.1088, -0.8365, -0.8072, -1.1285,
         10.1467, -0.4768],
        [-0.4666, -1.1942, -1.3262, -0.9974, -1.2714,  9.5323, -0.7755, -1.3075,
         -1.4948,  0.2609],
        [-0.6865, -1.1926, -0.8229, -1.1751, -1.7228, -1.3162, -1.4378, 10.5165,
         -0.7480, -1.5671],
        [-1.0367, -1.5300,  1.8218, -2.2579,  7.7501, -2.3345, -2.4375, -2.0289,
          2.0666, -2.3559],
        [-1.4916, -1.1091, -1.0129, -1.1207, -1.2405, -1.2283, -2.0613,  0.4728,
         -1.9758,  9.6872],
        [-2.2210, -1.1818, -0.3036, -0.9949,  0.6388, -1.0884, -1.2677, -1.3346,
         10.1165, -0.7700],
        [-1.2092, -1.5197, -2.1363, -0.4809, -1.3351, -1.3491, -1.6395, 10.5172,
         -1.0058, -0.4547],
        [-0.2218,  1.1324, -0.0138, -1.8326, -2.5046, -1.1532, -1.5526,  9.4708,
         -1.6365, -1.6368],
        [-1.3890, -0.8559, -1.1116, -0.9494,  2.5140, -1.7831,  8.1104, -2.0471,
         -1.1129, -1.9976],
        [ 0.7327, 

(tensor([[-1.9167, -2.1366, -1.2455, -1.2976, -0.8144, -1.0102, -2.3806,  2.5915,
         -1.5972,  8.6769],
        [ 0.4954,  0.0584,  9.3838, -2.2214, -0.7748, -2.0618, -1.8087, -2.3951,
         -1.5050, -1.3846],
        [ 0.6024, -1.1261,  9.4684, -1.9139, -0.4564, -2.0325, -1.0471, -2.7113,
         -1.6024, -1.6320],
        [-1.7881, -0.5281, -1.9772,  9.8267, -0.7524, -0.2064, -1.3448, -1.6810,
         -0.7850,  0.0348],
        [-2.0076, -1.1926, -0.7730, -0.9663, -0.1902, -0.6398, -0.9727, -0.9040,
          9.9454, -0.3511],
        [-0.0284, -0.2064, -1.8841, -0.2569, -1.1887,  9.7473, -0.5492, -1.9217,
         -1.5864, -0.9713],
        [-0.8347, -0.7178, -0.6327, -1.1629, -1.8301, -1.3077, -1.4310, 10.4575,
         -1.2348, -1.4011],
        [-1.3678, -2.1946, -0.6598, -1.4709,  9.6830, -1.4259, -1.1111, -1.3012,
         -0.4526, -2.0315],
        [-0.5090, 10.1947, -1.5814, -1.2845, -1.2820, -1.1033, -0.9091, -0.8197,
         -1.2867, -1.0423],
        [ 1.9525, 

(tensor([[-6.4394e-01,  1.0228e+01, -1.1641e+00, -1.5704e+00, -1.2969e+00,
         -1.0354e+00, -7.6597e-01, -8.5753e-01, -1.3471e+00, -1.1336e+00],
        [-2.2265e+00, -1.4109e+00,  1.3042e+00, -1.6806e+00,  2.4601e+00,
         -9.8659e-01,  7.0647e+00, -2.5703e+00, -1.4654e+00, -1.1434e+00],
        [ 2.2004e+00,  6.4574e-01,  7.1853e+00, -3.1151e+00, -9.6071e-01,
         -2.7622e+00,  2.2650e+00, -2.9206e+00, -2.7262e+00, -2.5146e+00],
        [ 9.4000e+00, -4.2309e-01, -5.8716e-03, -1.7074e+00, -1.0898e+00,
         -1.1697e+00, -1.7842e+00, -1.2715e+00, -1.5834e+00, -1.4881e+00],
        [ 9.3612e+00, -6.4487e-01,  1.5047e+00, -1.4093e+00, -1.4285e+00,
         -1.7673e+00, -2.2806e+00, -1.7095e+00, -1.3718e+00, -1.8127e+00],
        [-2.1305e+00, -1.0308e+00, -2.6844e-01, -9.9360e-01, -2.5596e-02,
         -1.0423e+00, -1.1712e+00, -1.0227e+00,  1.0135e+01, -5.7768e-01],
        [ 9.1700e+00, -2.8977e-01,  1.0372e+00, -2.1686e+00, -2.0008e+00,
         -7.8297e-01, -2.0414e+

(tensor([[-2.1006, -1.5191, -0.8480, -1.0196,  1.6498, -0.9471, -1.4693, -1.6748,
          9.9790, -0.8437],
        [-1.1670, -1.4098, -1.6646, -0.9354, -1.3017, -1.0459, -1.4691, 10.6483,
         -1.2532, -0.8687],
        [-1.5628, -1.2405,  1.0840, -3.4690,  6.1824,  1.3719,  2.8659, -2.9820,
         -0.0662, -3.1030],
        [-1.8383, -1.3499,  0.5646, -1.5951,  1.1386, -1.4774, -1.6429, -1.7360,
          9.6843, -0.5860],
        [-0.9353, -1.0529, -0.8820, -1.2575, -1.5616, -1.3946, -1.4649, 10.6119,
         -0.9080, -1.4084],
        [ 0.3246, -1.6156, -0.4387, -2.0785, -1.0786, -0.7260, -2.7938, -0.2787,
         -1.7952,  8.6295],
        [ 1.7839,  9.3085, -0.1131, -2.4466, -0.8796, -1.6661, -1.8909, -1.6054,
         -2.1056, -1.1343],
        [ 8.8980, -0.7872,  3.0152, -2.0254, -0.8861, -2.3778, -2.3954, -2.0871,
         -1.8915, -2.3724],
        [ 0.4374, -0.3926,  9.5037, -2.1517, -1.1475, -1.8185, -1.2484, -2.5264,
         -1.5817, -1.2926],
        [ 9.0649, 

(tensor([[-0.4157, 10.1534, -1.4280, -1.5356, -1.0170, -1.3432, -1.2117, -0.7212,
         -1.2256, -1.0551],
        [-0.6349, -1.1226, -0.3058, -0.8886,  0.7320, -1.5920,  9.4559, -1.2003,
         -2.4225, -1.6671],
        [-1.2772, -0.9444, -1.6251, 10.0739, -1.9852, -0.4952, -0.5101, -1.3273,
         -0.3746, -0.9455],
        [-1.1531, -2.4726,  2.0029, -2.5502,  8.4913, -1.2340, -0.8857, -2.1733,
          0.8696, -3.3332],
        [ 0.3356,  4.7037, -1.1851,  3.2679, -1.4484,  2.0227, -1.4516, -2.5522,
         -1.3342, -2.1277],
        [ 0.0860,  0.2447,  8.8278, -1.8489,  0.5804, -2.0875, -1.6414, -2.6461,
         -1.7926, -2.4923],
        [-1.4285, -1.5537, -1.2724,  0.6990, -0.7276, -0.7285,  9.5539, -1.3384,
         -1.3530, -1.3389],
        [-1.9339, -1.7335, -1.5516, -0.9988, -1.3741,  0.7507, -1.8395, -1.1256,
         -1.1283,  9.5362],
        [-1.7853, -1.0291, -2.1249,  0.0166, -1.0161, -0.8645, -1.8536, -0.1939,
         -1.6272,  9.7279],
        [ 2.4884, 

(tensor([[-7.6086e-01, -2.1610e+00, -1.2738e+00, -1.3609e+00,  9.7115e+00,
         -1.2971e+00, -7.3047e-01, -1.4464e+00, -9.1357e-01, -2.3096e+00],
        [-1.9900e+00, -1.0250e+00, -3.9662e-01, -1.0524e+00,  2.2338e-01,
         -8.0412e-01, -1.1642e+00, -1.3434e+00,  1.0177e+01, -6.6301e-01],
        [-1.2330e+00, -1.2897e+00, -1.2102e+00, -1.0895e+00, -1.8134e+00,
         -1.1459e+00, -1.2454e+00,  1.0696e+01, -8.3820e-01, -1.2042e+00],
        [-1.3424e+00, -8.3535e-01, -1.1993e+00, -1.5981e+00, -1.4510e+00,
         -5.6324e-01, -1.8926e+00,  1.1223e-01, -1.7563e+00,  9.8125e+00],
        [-1.3998e+00, -9.0296e-01, -1.3105e+00,  1.0322e+01, -1.9875e+00,
         -8.0548e-01, -7.8721e-01, -1.1791e+00, -8.2732e-01, -4.4796e-01],
        [-1.7647e+00, -1.3917e+00, -2.1679e+00,  9.9022e+00, -1.1868e+00,
         -8.3065e-03, -4.0286e-01, -1.5630e+00, -7.9647e-01,  1.2228e-01],
        [-1.0635e+00, -9.4195e-01, -1.9419e+00, -3.0231e-01, -9.7842e-01,
          9.9163e+00,  1.6078e-

KeyboardInterrupt: 

## 测试模型

In [None]:
# Build the batch generator over the held-out test split.
test_gen = DataGen(test_path, tokenizer=bert_tokenizer)
# Batches must be drawn from `test_gen`. Calling `data_gen` here would leave
# `test_gen` unused and iterate whatever dataset `data_gen` wraps instead.
# NOTE(review): assumes DataGen instances are callable with `batch_size`, as
# `data_gen` is used elsewhere in this notebook -- confirm against DataGen.
test_dataloader = test_gen(batch_size=32)

# Sanity check: show the first batch only (ids, attention mask, labels).
# NOTE(review): if `test_dataloader` is a one-shot generator, this peek
# consumes its first batch -- confirm, and re-create it before evaluating.
for data in test_dataloader:
    input_data, input_mask, input_labels = data
    print(input_data)
    print(input_mask)
    print(input_labels)
    break

In [None]:
# Evaluate the classifier on the test batches and report mean batch accuracy.
# All running state is initialised here so the cell does not depend on values
# left behind by an earlier cell (stale counters would skew the result).
model.eval()  # inference mode: turns off dropout
eval_accuracy, nb_eval_steps = 0, 0
t0 = time.time()

for batch in test_dataloader:
    # Move every tensor of the batch onto the model's device.
    batch = tuple(t.to(device) for t in batch)
    b_input_ids, b_input_mask, b_labels = batch

    # No gradients are needed for evaluation.
    with torch.no_grad():
        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask)

    # Raw outputs are deliberately not printed: dumping the logits of every
    # batch floods the cell output.
    # No labels were passed, so the first element is taken as the logits.
    logits = outputs[0].detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()

    eval_accuracy += flat_accuracy(logits, label_ids)
    nb_eval_steps += 1

if nb_eval_steps == 0:
    # An empty or already-exhausted dataloader would otherwise divide by zero.
    print("  No test batches were produced.")
else:
    print("  Accuracy: {0:.2f}".format(eval_accuracy / nb_eval_steps))
print("  Validation took: {:}".format(format_time(time.time() - t0)))

# ERNIE模型分类

In [69]:
# Load the tokenizer from the local ERNIE checkpoint directory, using the same
# BertTokenizer class that loads the bert-base-chinese tokenizer earlier on.
ernie_tokenizer = BertTokenizer.from_pretrained(
    '../../H/models/huggingface/ERNIE/')

# Disabled: build a 10-label sequence classifier from the ERNIE checkpoint,
# mark every parameter as trainable, and move the model to the GPU.
# model = BertForSequenceClassification.from_pretrained(
#     '../../H/models/huggingface/ERNIE/',  # load from local files
#     num_labels=10,
#     output_attentions=False,
#     output_hidden_states=False,
# )
# for param in model.parameters():
#     param.requires_grad = True
    
# model.cuda()

In [70]:
# Size of the ERNIE tokenizer's vocabulary.
len(ernie_tokenizer)

17964

In [71]:
# Tokenize one training headline with the bert-base-chinese tokenizer.
print(bert_tokenizer.tokenize(train_data[9999]))

['温', '家', '宝', '：', '去', '年', '进', '出', '口', '总', '额', '2', '.', '2', '万', '亿', '美', '元']


In [72]:
# Tokenize the same training headline with the ERNIE tokenizer for comparison.
print(ernie_tokenizer.tokenize(train_data[9999]))

['温', '家', '宝', '：', '去', '年', '进', '出', '口', '总', '额', '2', '.', '2', '万', '亿', '美', '元']


In [57]:
# NOTE(review): repeats the earlier len(ernie_tokenizer) check; this cell's
# execution count (57) is out of order relative to the cells above it.
len(ernie_tokenizer)

17964