In [1]:
import pandas as pd

def load_lcqmc():
    '''LCQMC文本匹配数据集
    '''
    train = pd.read_csv('https://mirror.coggle.club/dataset/LCQMC.train.data.zip', 
            sep='\t', names=['query1', 'query2', 'label'])

    valid = pd.read_csv('https://mirror.coggle.club/dataset/LCQMC.valid.data.zip', 
            sep='\t', names=['query1', 'query2', 'label'])

    test = pd.read_csv('https://mirror.coggle.club/dataset/LCQMC.test.data.zip', 
            sep='\t', names=['query1', 'query2', 'label'])

    return train, valid, test

In [2]:
from tqdm.notebook import tqdm
tqdm.pandas()

In [3]:
train,valid,test = load_lcqmc()

In [27]:
test

Unnamed: 0,query1,query2,label
0,谁有狂三这张高清的,这张高清图，谁有,0
1,英雄联盟什么英雄最好,英雄联盟最好英雄是什么,1
2,这是什么意思，被蹭网吗,我也是醉了，这是什么意思,0
3,现在有什么动画片好看呢？,现在有什么好看的动画片吗？,1
4,请问晶达电子厂现在的工资待遇怎么样要求有哪些,三星电子厂工资待遇怎么样啊,0
...,...,...,...
12495,微店怎么开？怎么做代理？,微店怎样代理,1
12496,小学科学三年级上,小学三年级科学,0
12497,冬眠是什么意思？,冬眠的意思是什么,1
12498,天猫有假货吗,天猫卖假货吗,0


In [4]:
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np
import pandas as pd
import random
import re

In [10]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_encoding = tokenizer(list(train["query1"]), list(train["query2"]), 
                           truncation=True, padding=True, max_length=100)
valid_encoding = tokenizer(list(valid["query1"]), list(valid["query2"]), 
                          truncation=True, padding=True, max_length=100)
test_encoding = tokenizer(list(test["query1"]), list(test["query2"]), 
                          truncation=True, padding=True, max_length=100)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


In [22]:
class BertDataset(Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels
        
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels is not None:
            item['label'] = torch.tensor(int(self.labels[idx]))
        return item
    
    def __len__(self):
        return len(self.encodings['input_ids'])

In [28]:
train_dataset = BertDataset(train_encoding, list(train["label"]))
valid_dataset = BertDataset(valid_encoding, list(valid["label"]))
test_dataset = BertDataset(test_encoding, list(test["label"]))

In [29]:
# 精度计算
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [31]:
from transformers import BertForNextSentencePrediction, AdamW, get_linear_schedule_with_warmup
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)
optim = AdamW(model.parameters(), lr=2e-5)
total_steps = len(train_loader) * 1
scheduler = get_linear_schedule_with_warmup(optim, 
                                            num_warmup_steps = 0, # Default value in run_glue.py
                                            num_training_steps = total_steps)

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForNextSentencePrediction: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [37]:
# 训练函数
def train():
    model.train()
    total_train_loss = 0
    iter_num = 0
    total_iter = len(train_loader)
    best_accuracy = 0
    for batch in tqdm(train_loader):
        # 正向传播
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        total_train_loss += loss.item()
        
        # 反向梯度信息
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        # 参数更新
        optim.step()
        scheduler.step()

        iter_num += 1
        if(iter_num % 100==0):
            print("epoth: %d, iter_num: %d, loss: %.4f, %.2f%%" % (epoch, iter_num, loss.item(), iter_num/total_iter*100))
        if(iter_num % 2000==0):
            accuracy = validation()
            if accuracy > best_accuracy:
                model.save_pretrained('G:\\deep_learning\\Coggle\\202301\\models\\bert\\'+'best_model')
                best_accuracy = accuracy
            model.train()
        
    print("Epoch: %d, Average training loss: %.4f"%(epoch, total_train_loss/len(train_loader)))

def validation():
    print("start validating...")
    model.eval()
    total_eval_accuracy = 0
    total_eval_loss = 0
    for batch in tqdm(valid_dataloader):
        with torch.no_grad():
            # 正常传播
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        
        loss = outputs[0]
        logits = outputs[1]

        total_eval_loss += loss.item()
        logits = logits.detach().cpu().numpy()
        label_ids = labels.to('cpu').numpy()
        total_eval_accuracy += flat_accuracy(logits, label_ids)
        
    avg_val_accuracy = total_eval_accuracy / len(valid_dataloader)
    print("Accuracy: %.4f" % (avg_val_accuracy))
    print("Average testing loss: %.4f"%(total_eval_loss/len(valid_dataloader)))
    print("-------------------------------")
    return avg_val_accuracy
    
for epoch in range(2):
    print("------------Epoch: %d ----------------" % epoch)
    train()

------------Epoch: 0 ----------------


  0%|          | 0/14923 [00:00<?, ?it/s]

epoth: 0, iter_num: 100, loss: 0.4770, 0.67%
epoth: 0, iter_num: 200, loss: 0.3446, 1.34%
epoth: 0, iter_num: 300, loss: 0.5103, 2.01%
epoth: 0, iter_num: 400, loss: 0.5488, 2.68%
epoth: 0, iter_num: 500, loss: 0.2589, 3.35%
epoth: 0, iter_num: 600, loss: 0.3024, 4.02%
epoth: 0, iter_num: 700, loss: 0.4811, 4.69%
epoth: 0, iter_num: 800, loss: 0.2117, 5.36%
epoth: 0, iter_num: 900, loss: 0.3745, 6.03%
epoth: 0, iter_num: 1000, loss: 0.7474, 6.70%
epoth: 0, iter_num: 1100, loss: 0.3683, 7.37%
epoth: 0, iter_num: 1200, loss: 0.4584, 8.04%
epoth: 0, iter_num: 1300, loss: 0.6007, 8.71%
epoth: 0, iter_num: 1400, loss: 0.4657, 9.38%
epoth: 0, iter_num: 1500, loss: 0.5084, 10.05%
epoth: 0, iter_num: 1600, loss: 0.3515, 10.72%
epoth: 0, iter_num: 1700, loss: 0.6481, 11.39%
epoth: 0, iter_num: 1800, loss: 0.2248, 12.06%
epoth: 0, iter_num: 1900, loss: 0.2088, 12.73%
epoth: 0, iter_num: 2000, loss: 0.4435, 13.40%
start validating...


  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6326
Average testing loss: 0.7390
-------------------------------
epoth: 0, iter_num: 2100, loss: 0.5067, 14.07%
epoth: 0, iter_num: 2200, loss: 0.3252, 14.74%
epoth: 0, iter_num: 2300, loss: 0.2887, 15.41%
epoth: 0, iter_num: 2400, loss: 0.4167, 16.08%
epoth: 0, iter_num: 2500, loss: 0.3514, 16.75%
epoth: 0, iter_num: 2600, loss: 0.6134, 17.42%
epoth: 0, iter_num: 2700, loss: 0.3206, 18.09%
epoth: 0, iter_num: 2800, loss: 0.5279, 18.76%
epoth: 0, iter_num: 2900, loss: 0.4058, 19.43%
epoth: 0, iter_num: 3000, loss: 0.6448, 20.10%
epoth: 0, iter_num: 3100, loss: 0.2499, 20.77%
epoth: 0, iter_num: 3200, loss: 0.3799, 21.44%
epoth: 0, iter_num: 3300, loss: 0.3192, 22.11%
epoth: 0, iter_num: 3400, loss: 0.5416, 22.78%
epoth: 0, iter_num: 3500, loss: 0.5221, 23.45%
epoth: 0, iter_num: 3600, loss: 0.5171, 24.12%
epoth: 0, iter_num: 3700, loss: 0.5124, 24.79%
epoth: 0, iter_num: 3800, loss: 0.2734, 25.46%
epoth: 0, iter_num: 3900, loss: 0.2892, 26.13%
epoth: 0, iter_num: 4000, los

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6386
Average testing loss: 0.6705
-------------------------------
epoth: 0, iter_num: 4100, loss: 0.4131, 27.47%
epoth: 0, iter_num: 4200, loss: 1.1143, 28.14%
epoth: 0, iter_num: 4300, loss: 0.5278, 28.81%
epoth: 0, iter_num: 4400, loss: 0.5171, 29.48%
epoth: 0, iter_num: 4500, loss: 0.2128, 30.15%
epoth: 0, iter_num: 4600, loss: 0.4496, 30.82%
epoth: 0, iter_num: 4700, loss: 0.3541, 31.50%
epoth: 0, iter_num: 4800, loss: 0.6367, 32.17%
epoth: 0, iter_num: 4900, loss: 0.5333, 32.84%
epoth: 0, iter_num: 5000, loss: 0.4099, 33.51%
epoth: 0, iter_num: 5100, loss: 0.5316, 34.18%
epoth: 0, iter_num: 5200, loss: 0.4185, 34.85%
epoth: 0, iter_num: 5300, loss: 0.4159, 35.52%
epoth: 0, iter_num: 5400, loss: 0.3114, 36.19%
epoth: 0, iter_num: 5500, loss: 0.1921, 36.86%
epoth: 0, iter_num: 5600, loss: 0.3504, 37.53%
epoth: 0, iter_num: 5700, loss: 0.3621, 38.20%
epoth: 0, iter_num: 5800, loss: 0.3592, 38.87%
epoth: 0, iter_num: 5900, loss: 0.4971, 39.54%
epoth: 0, iter_num: 6000, los

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6525
Average testing loss: 0.6698
-------------------------------
epoth: 0, iter_num: 6100, loss: 0.4936, 40.88%
epoth: 0, iter_num: 6200, loss: 0.4339, 41.55%
epoth: 0, iter_num: 6300, loss: 0.5967, 42.22%
epoth: 0, iter_num: 6400, loss: 0.4963, 42.89%
epoth: 0, iter_num: 6500, loss: 0.4942, 43.56%
epoth: 0, iter_num: 6600, loss: 0.3824, 44.23%
epoth: 0, iter_num: 6700, loss: 0.5544, 44.90%
epoth: 0, iter_num: 6800, loss: 0.4935, 45.57%
epoth: 0, iter_num: 6900, loss: 0.3454, 46.24%
epoth: 0, iter_num: 7000, loss: 0.5122, 46.91%
epoth: 0, iter_num: 7100, loss: 0.7890, 47.58%
epoth: 0, iter_num: 7200, loss: 0.3497, 48.25%
epoth: 0, iter_num: 7300, loss: 0.3689, 48.92%
epoth: 0, iter_num: 7400, loss: 0.3096, 49.59%
epoth: 0, iter_num: 7500, loss: 0.3764, 50.26%
epoth: 0, iter_num: 7600, loss: 0.3457, 50.93%
epoth: 0, iter_num: 7700, loss: 0.3192, 51.60%
epoth: 0, iter_num: 7800, loss: 0.5126, 52.27%
epoth: 0, iter_num: 7900, loss: 0.5002, 52.94%
epoth: 0, iter_num: 8000, los

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6665
Average testing loss: 0.6551
-------------------------------
epoth: 0, iter_num: 8100, loss: 0.6251, 54.28%
epoth: 0, iter_num: 8200, loss: 0.4665, 54.95%
epoth: 0, iter_num: 8300, loss: 0.1926, 55.62%
epoth: 0, iter_num: 8400, loss: 0.2826, 56.29%
epoth: 0, iter_num: 8500, loss: 0.4603, 56.96%
epoth: 0, iter_num: 8600, loss: 0.4957, 57.63%
epoth: 0, iter_num: 8700, loss: 0.6433, 58.30%
epoth: 0, iter_num: 8800, loss: 0.3475, 58.97%
epoth: 0, iter_num: 8900, loss: 0.2515, 59.64%
epoth: 0, iter_num: 9000, loss: 0.8815, 60.31%
epoth: 0, iter_num: 9100, loss: 0.7166, 60.98%
epoth: 0, iter_num: 9200, loss: 0.5350, 61.65%
epoth: 0, iter_num: 9300, loss: 0.8240, 62.32%
epoth: 0, iter_num: 9400, loss: 0.5438, 62.99%
epoth: 0, iter_num: 9500, loss: 0.5165, 63.66%
epoth: 0, iter_num: 9600, loss: 0.2897, 64.33%
epoth: 0, iter_num: 9700, loss: 0.3862, 65.00%
epoth: 0, iter_num: 9800, loss: 0.5518, 65.67%
epoth: 0, iter_num: 9900, loss: 0.5883, 66.34%
epoth: 0, iter_num: 10000, lo

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6600
Average testing loss: 0.7295
-------------------------------
epoth: 0, iter_num: 10100, loss: 0.4484, 67.68%
epoth: 0, iter_num: 10200, loss: 0.3196, 68.35%
epoth: 0, iter_num: 10300, loss: 0.3210, 69.02%
epoth: 0, iter_num: 10400, loss: 0.4871, 69.69%
epoth: 0, iter_num: 10500, loss: 0.3885, 70.36%
epoth: 0, iter_num: 10600, loss: 0.5365, 71.03%
epoth: 0, iter_num: 10700, loss: 0.4269, 71.70%
epoth: 0, iter_num: 10800, loss: 0.2700, 72.37%
epoth: 0, iter_num: 10900, loss: 0.5478, 73.04%
epoth: 0, iter_num: 11000, loss: 0.5074, 73.71%
epoth: 0, iter_num: 11100, loss: 0.4258, 74.38%
epoth: 0, iter_num: 11200, loss: 0.4690, 75.05%
epoth: 0, iter_num: 11300, loss: 0.5836, 75.72%
epoth: 0, iter_num: 11400, loss: 0.4232, 76.39%
epoth: 0, iter_num: 11500, loss: 0.2251, 77.06%
epoth: 0, iter_num: 11600, loss: 0.4410, 77.73%
epoth: 0, iter_num: 11700, loss: 0.6511, 78.40%
epoth: 0, iter_num: 11800, loss: 0.3957, 79.07%
epoth: 0, iter_num: 11900, loss: 0.2604, 79.74%
epoth: 0, 

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6659
Average testing loss: 0.6832
-------------------------------
epoth: 0, iter_num: 12100, loss: 0.4461, 81.08%
epoth: 0, iter_num: 12200, loss: 0.6147, 81.75%
epoth: 0, iter_num: 12300, loss: 0.2431, 82.42%
epoth: 0, iter_num: 12400, loss: 0.7061, 83.09%
epoth: 0, iter_num: 12500, loss: 0.3306, 83.76%
epoth: 0, iter_num: 12600, loss: 0.1999, 84.43%
epoth: 0, iter_num: 12700, loss: 0.4563, 85.10%
epoth: 0, iter_num: 12800, loss: 0.6702, 85.77%
epoth: 0, iter_num: 12900, loss: 0.3742, 86.44%
epoth: 0, iter_num: 13000, loss: 0.4320, 87.11%
epoth: 0, iter_num: 13100, loss: 0.5332, 87.78%
epoth: 0, iter_num: 13200, loss: 0.3181, 88.45%
epoth: 0, iter_num: 13300, loss: 0.6247, 89.12%
epoth: 0, iter_num: 13400, loss: 0.6180, 89.79%
epoth: 0, iter_num: 13500, loss: 0.4005, 90.46%
epoth: 0, iter_num: 13600, loss: 0.3916, 91.13%
epoth: 0, iter_num: 13700, loss: 0.5899, 91.80%
epoth: 0, iter_num: 13800, loss: 0.3105, 92.47%
epoth: 0, iter_num: 13900, loss: 0.1740, 93.14%
epoth: 0, 

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6607
Average testing loss: 0.6923
-------------------------------
epoth: 0, iter_num: 14100, loss: 0.4624, 94.49%
epoth: 0, iter_num: 14200, loss: 0.2329, 95.16%
epoth: 0, iter_num: 14300, loss: 0.4855, 95.83%
epoth: 0, iter_num: 14400, loss: 0.2863, 96.50%
epoth: 0, iter_num: 14500, loss: 0.4529, 97.17%
epoth: 0, iter_num: 14600, loss: 0.1956, 97.84%
epoth: 0, iter_num: 14700, loss: 0.4693, 98.51%
epoth: 0, iter_num: 14800, loss: 0.1976, 99.18%
epoth: 0, iter_num: 14900, loss: 0.5175, 99.85%
Epoch: 0, Average training loss: 0.4563
------------Epoch: 1 ----------------


  0%|          | 0/14923 [00:00<?, ?it/s]

epoth: 1, iter_num: 100, loss: 0.4583, 0.67%
epoth: 1, iter_num: 200, loss: 0.1788, 1.34%
epoth: 1, iter_num: 300, loss: 0.4183, 2.01%
epoth: 1, iter_num: 400, loss: 0.5604, 2.68%
epoth: 1, iter_num: 500, loss: 0.7020, 3.35%
epoth: 1, iter_num: 600, loss: 0.5439, 4.02%
epoth: 1, iter_num: 700, loss: 0.2564, 4.69%
epoth: 1, iter_num: 800, loss: 0.3430, 5.36%
epoth: 1, iter_num: 900, loss: 0.5363, 6.03%
epoth: 1, iter_num: 1000, loss: 0.5805, 6.70%
epoth: 1, iter_num: 1100, loss: 0.3551, 7.37%
epoth: 1, iter_num: 1200, loss: 0.3437, 8.04%
epoth: 1, iter_num: 1300, loss: 0.4601, 8.71%
epoth: 1, iter_num: 1400, loss: 0.5684, 9.38%
epoth: 1, iter_num: 1500, loss: 0.3458, 10.05%
epoth: 1, iter_num: 1600, loss: 0.4877, 10.72%
epoth: 1, iter_num: 1700, loss: 0.3399, 11.39%
epoth: 1, iter_num: 1800, loss: 0.5647, 12.06%
epoth: 1, iter_num: 1900, loss: 0.3046, 12.73%
epoth: 1, iter_num: 2000, loss: 0.4021, 13.40%
start validating...


  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6623
Average testing loss: 0.6906
-------------------------------
epoth: 1, iter_num: 2100, loss: 0.5105, 14.07%
epoth: 1, iter_num: 2200, loss: 0.2062, 14.74%
epoth: 1, iter_num: 2300, loss: 0.6568, 15.41%
epoth: 1, iter_num: 2400, loss: 0.5795, 16.08%
epoth: 1, iter_num: 2500, loss: 0.3263, 16.75%
epoth: 1, iter_num: 2600, loss: 0.4126, 17.42%
epoth: 1, iter_num: 2700, loss: 0.3250, 18.09%
epoth: 1, iter_num: 2800, loss: 0.5757, 18.76%
epoth: 1, iter_num: 2900, loss: 0.2883, 19.43%
epoth: 1, iter_num: 3000, loss: 0.3048, 20.10%
epoth: 1, iter_num: 3100, loss: 0.6363, 20.77%
epoth: 1, iter_num: 3200, loss: 0.2108, 21.44%
epoth: 1, iter_num: 3300, loss: 0.4729, 22.11%
epoth: 1, iter_num: 3400, loss: 0.7148, 22.78%
epoth: 1, iter_num: 3500, loss: 0.3465, 23.45%
epoth: 1, iter_num: 3600, loss: 0.3879, 24.12%
epoth: 1, iter_num: 3700, loss: 0.5521, 24.79%
epoth: 1, iter_num: 3800, loss: 0.4358, 25.46%
epoth: 1, iter_num: 3900, loss: 0.4576, 26.13%
epoth: 1, iter_num: 4000, los

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6623
Average testing loss: 0.6907
-------------------------------
epoth: 1, iter_num: 4100, loss: 0.3103, 27.47%
epoth: 1, iter_num: 4200, loss: 0.5899, 28.14%
epoth: 1, iter_num: 4300, loss: 0.4618, 28.81%
epoth: 1, iter_num: 4400, loss: 0.4518, 29.48%
epoth: 1, iter_num: 4500, loss: 0.1667, 30.15%
epoth: 1, iter_num: 4600, loss: 0.7979, 30.82%
epoth: 1, iter_num: 4700, loss: 0.5755, 31.50%
epoth: 1, iter_num: 4800, loss: 0.4325, 32.17%
epoth: 1, iter_num: 4900, loss: 0.6195, 32.84%
epoth: 1, iter_num: 5000, loss: 0.1768, 33.51%
epoth: 1, iter_num: 5100, loss: 0.3167, 34.18%
epoth: 1, iter_num: 5200, loss: 0.4819, 34.85%
epoth: 1, iter_num: 5300, loss: 0.4183, 35.52%
epoth: 1, iter_num: 5400, loss: 0.4445, 36.19%
epoth: 1, iter_num: 5500, loss: 0.3647, 36.86%
epoth: 1, iter_num: 5600, loss: 0.3019, 37.53%
epoth: 1, iter_num: 5700, loss: 0.6440, 38.20%
epoth: 1, iter_num: 5800, loss: 0.6556, 38.87%
epoth: 1, iter_num: 5900, loss: 0.4550, 39.54%
epoth: 1, iter_num: 6000, los

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6623
Average testing loss: 0.6906
-------------------------------
epoth: 1, iter_num: 6100, loss: 0.2769, 40.88%
epoth: 1, iter_num: 6200, loss: 0.8064, 41.55%
epoth: 1, iter_num: 6300, loss: 0.3305, 42.22%
epoth: 1, iter_num: 6400, loss: 0.7773, 42.89%
epoth: 1, iter_num: 6500, loss: 0.4705, 43.56%
epoth: 1, iter_num: 6600, loss: 0.3473, 44.23%
epoth: 1, iter_num: 6700, loss: 0.4288, 44.90%
epoth: 1, iter_num: 6800, loss: 0.2947, 45.57%
epoth: 1, iter_num: 6900, loss: 0.3667, 46.24%
epoth: 1, iter_num: 7000, loss: 0.2845, 46.91%
epoth: 1, iter_num: 7100, loss: 0.4927, 47.58%
epoth: 1, iter_num: 7200, loss: 0.3802, 48.25%
epoth: 1, iter_num: 7300, loss: 0.3839, 48.92%
epoth: 1, iter_num: 7400, loss: 0.3932, 49.59%
epoth: 1, iter_num: 7500, loss: 0.3669, 50.26%
epoth: 1, iter_num: 7600, loss: 0.4254, 50.93%
epoth: 1, iter_num: 7700, loss: 0.2945, 51.60%
epoth: 1, iter_num: 7800, loss: 0.3260, 52.27%
epoth: 1, iter_num: 7900, loss: 0.6517, 52.94%
epoth: 1, iter_num: 8000, los

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6623
Average testing loss: 0.6909
-------------------------------
epoth: 1, iter_num: 8100, loss: 0.2764, 54.28%
epoth: 1, iter_num: 8200, loss: 0.3244, 54.95%
epoth: 1, iter_num: 8300, loss: 0.3972, 55.62%
epoth: 1, iter_num: 8400, loss: 0.2501, 56.29%
epoth: 1, iter_num: 8500, loss: 0.4541, 56.96%
epoth: 1, iter_num: 8600, loss: 0.3528, 57.63%
epoth: 1, iter_num: 8700, loss: 0.2288, 58.30%
epoth: 1, iter_num: 8800, loss: 0.4839, 58.97%
epoth: 1, iter_num: 8900, loss: 0.3158, 59.64%
epoth: 1, iter_num: 9000, loss: 0.3193, 60.31%
epoth: 1, iter_num: 9100, loss: 0.4661, 60.98%
epoth: 1, iter_num: 9200, loss: 0.4522, 61.65%
epoth: 1, iter_num: 9300, loss: 0.2321, 62.32%
epoth: 1, iter_num: 9400, loss: 0.2995, 62.99%
epoth: 1, iter_num: 9500, loss: 0.4912, 63.66%
epoth: 1, iter_num: 9600, loss: 0.8562, 64.33%
epoth: 1, iter_num: 9700, loss: 0.3232, 65.00%
epoth: 1, iter_num: 9800, loss: 0.3940, 65.67%
epoth: 1, iter_num: 9900, loss: 0.6963, 66.34%
epoth: 1, iter_num: 10000, lo

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6615
Average testing loss: 0.6912
-------------------------------
epoth: 1, iter_num: 10100, loss: 0.2770, 67.68%
epoth: 1, iter_num: 10200, loss: 0.6129, 68.35%
epoth: 1, iter_num: 10300, loss: 0.3995, 69.02%
epoth: 1, iter_num: 10400, loss: 0.3941, 69.69%
epoth: 1, iter_num: 10500, loss: 0.1320, 70.36%
epoth: 1, iter_num: 10600, loss: 0.5397, 71.03%
epoth: 1, iter_num: 10700, loss: 0.3295, 71.70%
epoth: 1, iter_num: 10800, loss: 0.3763, 72.37%
epoth: 1, iter_num: 10900, loss: 0.5937, 73.04%
epoth: 1, iter_num: 11000, loss: 0.3672, 73.71%
epoth: 1, iter_num: 11100, loss: 0.3007, 74.38%
epoth: 1, iter_num: 11200, loss: 0.5849, 75.05%
epoth: 1, iter_num: 11300, loss: 0.2548, 75.72%
epoth: 1, iter_num: 11400, loss: 0.4279, 76.39%
epoth: 1, iter_num: 11500, loss: 0.4964, 77.06%
epoth: 1, iter_num: 11600, loss: 0.4257, 77.73%
epoth: 1, iter_num: 11700, loss: 0.7598, 78.40%
epoth: 1, iter_num: 11800, loss: 0.3319, 79.07%
epoth: 1, iter_num: 11900, loss: 0.2025, 79.74%
epoth: 1, 

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6623
Average testing loss: 0.6905
-------------------------------
epoth: 1, iter_num: 12100, loss: 0.1745, 81.08%
epoth: 1, iter_num: 12200, loss: 0.3585, 81.75%
epoth: 1, iter_num: 12300, loss: 0.3248, 82.42%
epoth: 1, iter_num: 12400, loss: 0.2921, 83.09%
epoth: 1, iter_num: 12500, loss: 0.3009, 83.76%
epoth: 1, iter_num: 12600, loss: 0.3824, 84.43%
epoth: 1, iter_num: 12700, loss: 0.4865, 85.10%
epoth: 1, iter_num: 12800, loss: 0.3900, 85.77%
epoth: 1, iter_num: 12900, loss: 0.3494, 86.44%
epoth: 1, iter_num: 13000, loss: 0.7520, 87.11%
epoth: 1, iter_num: 13100, loss: 0.4261, 87.78%
epoth: 1, iter_num: 13200, loss: 0.4549, 88.45%
epoth: 1, iter_num: 13300, loss: 0.6242, 89.12%
epoth: 1, iter_num: 13400, loss: 0.5382, 89.79%
epoth: 1, iter_num: 13500, loss: 0.2922, 90.46%
epoth: 1, iter_num: 13600, loss: 0.4296, 91.13%
epoth: 1, iter_num: 13700, loss: 0.5860, 91.80%
epoth: 1, iter_num: 13800, loss: 0.2475, 92.47%
epoth: 1, iter_num: 13900, loss: 0.5761, 93.14%
epoth: 1, 

  0%|          | 0/551 [00:00<?, ?it/s]

Accuracy: 0.6623
Average testing loss: 0.6905
-------------------------------
epoth: 1, iter_num: 14100, loss: 0.3787, 94.49%
epoth: 1, iter_num: 14200, loss: 0.5444, 95.16%
epoth: 1, iter_num: 14300, loss: 0.5031, 95.83%
epoth: 1, iter_num: 14400, loss: 0.2669, 96.50%
epoth: 1, iter_num: 14500, loss: 0.2130, 97.17%
epoth: 1, iter_num: 14600, loss: 0.4135, 97.84%
epoth: 1, iter_num: 14700, loss: 0.4612, 98.51%
epoth: 1, iter_num: 14800, loss: 0.1833, 99.18%
epoth: 1, iter_num: 14900, loss: 0.3405, 99.85%
Epoch: 1, Average training loss: 0.4236


In [39]:
model = BertForNextSentencePrediction.from_pretrained('G:\\deep_learning\\Coggle\\202301\\models\\bert\\best_model')
model.to(device)
def test():
    print("start testing...")
    model.eval()
    total_eval_accuracy = 0
    total_eval_loss = 0
    for batch in tqdm(test_dataloader):
        with torch.no_grad():
            # 正常传播
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        
        loss = outputs[0]
        logits = outputs[1]

        total_eval_loss += loss.item()
        logits = logits.detach().cpu().numpy()
        label_ids = labels.to('cpu').numpy()
        total_eval_accuracy += flat_accuracy(logits, label_ids)
        
    avg_val_accuracy = total_eval_accuracy / len(test_dataloader)
    print("Accuracy: %.4f" % (avg_val_accuracy))
    print("Average testing loss: %.4f"%(total_eval_loss/len(test_dataloader)))
    print("-------------------------------")
    return avg_val_accuracy

test()

start testing...


  0%|          | 0/782 [00:00<?, ?it/s]

Accuracy: 0.6649
Average testing loss: 0.6602
-------------------------------


0.6648817135549873