# Longformer Sentiment Classification Model

## Task Objectives
* Use the Longformer model for sentiment classification
* Measure how long the task takes and record the results

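## Business Background
Business background: given the text of a review, determine whether it is positive or negative.

Data: review texts with their ratings (1 = positive, 0 = negative). Dataset: 'seamew/ChnSentiCorp'

Examples:

'房间太小。其他的都一般。。。。。。。。。', 0 (roughly: "The room is too small. Everything else is average.")

'15.4寸笔记本的键盘确实爽，基本跟台式机差不多了，蛮喜欢数字小键盘，输数字特方便，样子也很美观，做工也相当不错', 1 (roughly: "The keyboard on this 15.4-inch laptop is really nice, almost like a desktop keyboard; I like the numeric keypad, which makes typing numbers very convenient; it also looks good and the build quality is quite solid.")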

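## Loading Data with the Dataset Class
PyTorch uses the Dataset and DataLoader classes to organize a dataset and load it in batches. We therefore first write a custom dataset that inherits from Dataset to hold the samples and their labels.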

In [5]:
import torch

from datasets import load_dataset
from datasets import load_from_disk


# Define the dataset so the model can read batches from it later.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data_type):
        self.data = self.load_data(data_type)

    # load_data is the part to change when adapting this class
    # to the dataset of your own task.
    def load_data(self, data_type):
        # Load the copy saved on disk; data_type is 'train', 'validation' or 'test'.
        # Alternatives:
        # tmp_dataset = load_dataset('csv', data_files='../data/ChnSentiCorp.csv', split=data_type)
        # tmp_dataset = load_dataset(path='seamew/ChnSentiCorp', split=data_type)
        tmp_dataset = load_from_disk('./data/ChnSentiCorp')[data_type]
        Data = {}
        for idx, line in enumerate(tmp_dataset):
            sample = line
            Data[idx] = sample
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    


In [2]:
import warnings
warnings.filterwarnings("ignore")

import torch

from datasets import load_dataset
from datasets import load_from_disk


dataset = load_from_disk('./data/ChnSentiCorp')

In [2]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 9600
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
})

In [3]:
train_data = dataset['train']
valid_data = dataset['validation']
test_data = dataset['test']

In [4]:
for i in range(10):
    print(train_data[i])

{'text': '选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般', 'label': 1}
{'text': '15.4寸笔记本的键盘确实爽，基本跟台式机差不多了，蛮喜欢数字小键盘，输数字特方便，样子也很美观，做工也相当不错', 'label': 1}
{'text': '房间太小。其他的都一般。。。。。。。。。', 'label': 0}
{'text': '1.接电源没有几分钟,电源适配器热的不行. 2.摄像头用不起来. 3.机盖的钢琴漆，手不能摸，一摸一个印. 4.硬盘分区不好办.', 'label': 0}
{'text': '今天才知道这书还有第6卷,真有点郁闷:为什么同一套书有两种版本呢?当当网是不是该跟出版社商量商量,单独出个第6卷,让我们的孩子不会有所遗憾。', 'label': 1}
{'text': '机器背面似乎被撕了张什么标签，残胶还在。但是又看不出是什么标签不见了，该有的都在，怪', 'label': 0}
{'text': '呵呵，虽然表皮看上去不错很精致，但是我还是能看得出来是盗的。但是里面的内容真的不错，我妈爱看，我自己也学着找一些穴位。', 'label': 0}
{'text': '这本书实在是太烂了,以前听浙大的老师说这本书怎么怎么不对,哪些地方都是误导的还不相信,终于买了一本看一下,发现真是~~~无语,这种书都写得出来', 'label': 0}
{'text': '地理位置佳，在市中心。酒店服务好、早餐品种丰富。我住的商务数码房电脑宽带速度满意,房间还算干净，离湖南路小吃街近。', 'label': 1}
{'text': '5.1期间在这住的，位置还可以，在市委市政府附近，要去商业区和步行街得打车，屋里有蚊子，虽然空间挺大，晚上熄灯后把窗帘拉上简直是伸手不见五指，很适合睡觉，但是会被该死的蚊子吵醒！打死了两只，第二天早上还是发现又没打死的，卫生间挺大，但是设备很老旧。', 'label': 1}


In [5]:
print(len(train_data))
print(len(valid_data))
print(len(test_data))

9600
1200
1200


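## DataLoader
Once the dataset is created, a DataLoader loads it in batches and arranges the samples into the input format the model accepts. For NLP tasks, this step encodes the sentences in a batch as the pretrained model requires (each sample here is a single review, not a sentence pair), including padding and truncation. This is done by passing a batch-processing function to the DataLoader through its collate_fn argument.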

In [8]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

checkpoint = "schen/longformer-chinese-base-4096"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Batch-processing (collate) function
def collate_fn(batch_samples):
    batch_text = []
    batch_label = []
    for sample in batch_samples:
        batch_text.append(sample['text'])
        batch_label.append(int(sample['label']))
    X = tokenizer(
        batch_text,
        padding="max_length",  # alternative: padding=True pads to the longest text in the batch
        max_length=4096,
        truncation=True,
        return_tensors="pt"
    )
    y = torch.tensor(batch_label)
    return X, y

# batch_size=2 ran to completion; batch_size=4 could not finish.
# (The sample output below has 4 rows, so it was captured with batch_size=4.)
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_data, batch_size=2, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_data, batch_size=2, shuffle=True, collate_fn=collate_fn)

batch_X, batch_y = next(iter(train_dataloader))
print('batch_X shape:', {k: v.shape for k, v in batch_X.items()})
print('batch_y shape:', batch_y.shape)
print(batch_X)
print(batch_y)

batch_X shape: {'input_ids': torch.Size([4, 4096]), 'token_type_ids': torch.Size([4, 4096]), 'attention_mask': torch.Size([4, 4096])}
batch_y shape: torch.Size([4])
{'input_ids': tensor([[ 101, 2523, 2207,  ...,    0,    0,    0],
        [ 101, 1788, 4635,  ...,    0,    0,    0],
        [ 101, 2400, 3766,  ...,    0,    0,    0],
        [ 101,  679, 5052,  ...,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}
tensor([1, 1, 0, 1])
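
The collate function above pads every batch to `max_length=4096`, although the sample reviews shown earlier are far shorter than that. A minimal sketch of the alternative, padding only to the longest sequence in the batch (what `padding=True` does in the tokenizer call), is given below in plain Python so the effect on batch size is easy to see. The helper `pad_batch` and the toy token ids are hypothetical and only illustrate the idea; they are not part of the notebook.

```python
def pad_batch(sequences, pad_id=0, max_length=None):
    """Pad token-id lists to a common length and build the attention mask.

    max_length=None pads to the longest sequence in the batch (dynamic padding);
    an integer pads (and truncates) every sequence to that fixed length.
    """
    target = max_length if max_length is not None else max(len(s) for s in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        seq = seq[:target]  # truncate if longer than the target length
        n_pad = target - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask


# Toy batch: two "reviews" of 5 and 3 tokens (the ids are made up).
batch = [[101, 2523, 2207, 511, 102], [101, 679, 102]]

dyn_ids, dyn_mask = pad_batch(batch)                   # dynamic padding
fix_ids, fix_mask = pad_batch(batch, max_length=4096)  # fixed-length padding

print(len(dyn_ids[0]), len(fix_ids[0]))  # row width under each strategy
print(dyn_mask[1])                       # mask of the shorter sequence
```

In the notebook's collate function this would correspond to using `padding=True` instead of `padding="max_length"`, while keeping `truncation=True` and `max_length=4096` as an upper bound.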


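## Model Definition
The pretrained model serves only as the encoder, and a model usually contains further custom modules, so this notebook writes its own PyTorch model: it first loads the Longformer checkpoint with the Transformers library, then adds a fully connected layer for classification.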

In [9]:
from torch import nn
from transformers import AutoModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

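class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # The encoder gives access to the last-layer output of the pretrained model.
        # Note: as the output below shows, AutoModel loads this checkpoint as a BertModel,
        # so the checkpoint's *_global attention weights are not used.
        self.bert_encoder = AutoModel.from_pretrained(checkpoint)
        # Classification layer: 768-dimensional input, 2 output classes.
        # Other layers could be added here to improve results.
        self.classifier = nn.Linear(768, 2)

    def forward(self, x):
        bert_output = self.bert_encoder(**x)
        # Take the first token of the last layer, which represents the whole sentence.
        cls_vectors = bert_output.last_hidden_state[:, 0]
        # Feed it to the linear classification layer.
        logits = self.classifier(cls_vectors)
        return logits

# Move the model to the GPU if one is available.
model = NeuralNetwork().to(device)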
print(model)

Using cpu device


Some weights of the model checkpoint at schen/longformer-chinese-base-4096 were not used when initializing BertModel: ['bert.encoder.layer.0.attention.self.key_global.weight', 'bert.encoder.layer.9.attention.self.key_global.weight', 'bert.encoder.layer.2.attention.self.value_global.weight', 'bert.encoder.layer.10.attention.self.value_global.bias', 'bert.encoder.layer.1.attention.self.key_global.bias', 'bert.encoder.layer.7.attention.self.value_global.bias', 'bert.encoder.layer.8.attention.self.query_global.weight', 'bert.encoder.layer.5.attention.self.query_global.bias', 'bert.encoder.layer.4.attention.self.value_global.weight', 'bert.encoder.layer.6.attention.self.query_global.bias', 'bert.encoder.layer.6.attention.self.value_global.bias', 'bert.encoder.layer.8.attention.self.key_global.bias', 'bert.encoder.layer.1.attention.self.query_global.weight', 'bert.encoder.layer.6.attention.self.value_global.weight', 'bert.encoder.layer.10.attention.self.key_global.weight', 'bert.encoder.laye

NeuralNetwork(
  (bert_encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(4096, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

Test notes

In [10]:
total = sum(p.numel() for p in model.parameters())
print("total param:",total)

total param: 105021698
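
The total of 105,021,698 parameters can be reproduced by hand from the layer shapes printed above. The sketch below assumes the standard BERT-base layout for the parts the truncated printout does not show (12 encoder layers, a 3072-wide feed-forward block, and a pooler layer); those values are assumptions, while the vocabulary size, position count and hidden size come from the printed model.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


hidden, vocab, positions, token_types = 768, 21128, 4096, 2  # from the printout
layers, intermediate = 12, 3072                              # assumed (BERT-base defaults)
layer_norm = 2 * hidden                                      # scale and shift vectors

embeddings = (vocab + positions + token_types) * hidden + layer_norm
per_layer = (
    4 * linear_params(hidden, hidden)      # query, key, value, attention output
    + layer_norm
    + linear_params(hidden, intermediate)  # feed-forward up-projection
    + linear_params(intermediate, hidden)  # feed-forward down-projection
    + layer_norm
)
pooler = linear_params(hidden, hidden)     # assumed: BertModel's default pooling layer
classifier = linear_params(hidden, 2)      # the nn.Linear(768, 2) head

total = embeddings + layers * per_layer + pooler + classifier
print(total)  # prints 105021698
```

Under these assumptions the 4096 position embeddings contribute 4096 × 768 = 3,145,728 parameters, and the classification head only 1,538.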


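## Model Training
Each epoch is split into a training loop and a validation/test loop. The training loop computes the loss and updates the model parameters; the validation/test loop evaluates the model's performance.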

In [8]:
from tqdm.auto import tqdm  # progress bar for nicer output

# Parameters
# dataloader   : DataLoader that yields the batches
# model        : the model defined above
# loss_fn      : the loss function
# optimizer    : the optimizer
# lr_scheduler : adjusts the learning rate as the step count grows; a decaying rate usually works better than a fixed one
# epoch        : the current epoch number (1-based)
# total_loss   : loss accumulated over all batches so far
def train_loop(dataloader, model, loss_fn, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1)*len(dataloader)
    
    model.train()
    for batch, (X, y) in enumerate(dataloader, start=1):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()  # clear the gradients from the previous step
        loss.backward()  # backpropagate
        optimizer.step()  # update the parameters using the gradients
        lr_scheduler.step()  # adjust the learning rate

        total_loss += loss.item()  # accumulate the overall loss
        # With batch_size=2, print the loss every 600 batches;
        # with batch_size=4, every 300 batches would do.
        # batch is 1-based, so this fires at batch 599, 1199, ... and labels the line batch+1.
        if batch % 600 == 599:
            print('[%d,%5d] running_loss:%.3f' % (epoch,batch+1,total_loss/(finish_batch_num + batch)))
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
        
        
        # Alternative: report the average loss of each group of 600 batches.
        # running_loss += loss.item() 
        # Print the loss every 600 batches
        # if batch % 600 == 599:
            # print('[%d,%5d] running_loss:%.3f' % (epoch,batch+1,running_loss/600))
            # Reset after each group of 600 batches
            # running_loss = 0.0
    return total_loss

def test_loop(dataloader, model, mode='Test'):
    assert mode in ['Valid', 'Test']
    size = len(dataloader.dataset)
    correct = 0

    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    correct /= size
    print(f"{mode} Accuracy: {(100*correct):>0.1f}%\n")
    return correct

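Finally, combining the training loop and the validation/test loop into epochs lets us train and validate the model. The Transformers library also provides optimizers and learning-rate schedulers (via get_scheduler): instead of keeping the learning rate fixed, the schedule used here lowers it step by step during training, which usually gives better results. The code also calls torch.save(model.state_dict(), 'xx') to save the model parameters.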
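
The linear schedule used in this notebook can be written out directly. The sketch below is a plain-Python approximation of what `get_scheduler("linear", ...)` computes with `num_warmup_steps=0`: the learning rate falls in a straight line from its initial value to zero over `num_training_steps`. The function name `linear_lr` is hypothetical, and the step count of 10 × 4800 assumes the batch_size=2 run.

```python
def linear_lr(step, base_lr, num_training_steps, num_warmup_steps=0):
    """Learning rate after `step` optimizer updates under linear warmup and decay."""
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    decay_span = max(1, num_training_steps - num_warmup_steps)
    return base_lr * max(0.0, remaining / decay_span)


total_steps = 10 * 4800  # epochs x batches per epoch (batch_size=2 run)
for step in (0, 24000, 48000):
    print(step, linear_lr(step, 1e-5, total_steps))
```

Halfway through training the rate is half of its starting value, and it reaches zero on the final step.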

## Model Definition, 2022-11-07

In [11]:
# Define the evaluation function
import numpy as np
from datasets import load_metric
def compute_metrics(eval_preds):
    metric = load_metric('glue','mrpc')
    logits,labels = eval_preds  # predictions and ground-truth labels
    predictions = np.argmax(logits,axis=-1)
    return metric.compute(predictions=predictions,references=labels)

In [12]:
# Define the trainer
from transformers import TrainingArguments,Trainer
# Initialize the training arguments
args = TrainingArguments(
    output_dir='/home/chenli/pre_model/20221107/',
    evaluation_strategy='epoch',  # how often to run evaluation: once per epoch
    num_train_epochs=10,  # number of training epochs
    learning_rate=1e-5,  # learning rate
    weight_decay=1e-2,
    per_device_eval_batch_size=32,
    per_device_train_batch_size=16,
)

In [13]:
# Initialize the trainer
trainer = Trainer(
    model = model,
    args = args,
    train_dataset=train_data,
    eval_dataset=valid_data,
    compute_metrics=compute_metrics
)

In [14]:
# Evaluate the model
trainer.evaluate()

The following columns in the evaluation set don't have a corresponding argument in `NeuralNetwork.forward` and have been ignored: text. If text are not expected by `NeuralNetwork.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1200
  Batch size = 32


TypeError: forward() got an unexpected keyword argument 'labels'
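
The `TypeError` above arises because `Trainer` calls the model with keyword arguments taken from the dataset columns (`labels` among them), whereas `NeuralNetwork.forward` accepts a single positional batch `x`. The message about the ignored `text` column also indicates that the raw text was not tokenized before being passed in. One possible way to make a custom module usable with `Trainer` is sketched below: `forward` accepts the tokenizer fields and `labels` by name and returns a dict containing `loss` and `logits`. The class names `TinyEncoder` and `TrainerReadyClassifier` are hypothetical, and the tiny stand-in encoder is used only so the example runs without downloading a checkpoint; in the notebook the encoder would be the pretrained model.

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Stand-in for the pretrained encoder: returns one vector per token."""

    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        return self.embed(input_ids)  # shape: (batch, seq_len, hidden)


class TrainerReadyClassifier(nn.Module):
    def __init__(self, encoder, hidden=16, num_labels=2):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(hidden, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        hidden_states = self.encoder(input_ids, attention_mask, token_type_ids)
        logits = self.classifier(hidden_states[:, 0])  # first-token vector
        if labels is None:
            return {"logits": logits}
        # Return the loss alongside the logits when labels are supplied.
        return {"loss": self.loss_fn(logits, labels), "logits": logits}


model_sketch = TrainerReadyClassifier(TinyEncoder())
batch = {
    "input_ids": torch.tensor([[1, 5, 7, 0], [1, 9, 0, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
    "labels": torch.tensor([1, 0]),
}
out = model_sketch(**batch)
print(out["logits"].shape, out["loss"].item())
```

The dataset passed to `Trainer` would also need to be tokenized first (for example with `Dataset.map`) so that it contains `input_ids` and `attention_mask` columns instead of raw `text`.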

# 2022.10.22 training results
batch=2, 4800 batches per epoch,
epochs=10; finished on 2022.10.23

In [17]:
from transformers import AdamW, get_scheduler

learning_rate = 1e-5  # learning rate
epoch_num = 10  # number of epochs

loss_fn = nn.CrossEntropyLoss()  # loss function: cross-entropy
optimizer = AdamW(model.parameters(), lr=learning_rate)  # AdamW, a commonly used optimizer
lr_scheduler = get_scheduler(
    "linear",  # decrease the learning rate linearly
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
best_acc = 0.
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, loss_fn, optimizer, lr_scheduler, t+1, total_loss)
    valid_acc = test_loop(valid_dataloader, model, mode='Valid')
    if valid_acc > best_acc:
        best_acc = valid_acc
        print('saving new weights...\n')
        # Save the model weights
        torch.save(model.state_dict(), f'epoch_{t+1}_valid_acc_{(100*valid_acc):0.1f}_model_weights.bin')
print("Done!")

# A checkpoint is written whenever validation accuracy improves, so the last file saved is the best model

Epoch 1/10
-------------------------------


loss: 0.365577:  12%|█▏        | 599/4800 [21:31<2:33:15,  2.19s/it]

[1,  600] running_loss:0.366


loss: 0.318970:  25%|██▍       | 1199/4800 [43:23<1:22:38,  1.38s/it]

[1, 1200] running_loss:0.319


loss: 0.304193:  37%|███▋      | 1799/4800 [1:04:21<2:26:44,  2.93s/it]

[1, 1800] running_loss:0.304


loss: 0.288217:  50%|████▉     | 2399/4800 [1:25:37<1:00:52,  1.52s/it]

[1, 2400] running_loss:0.288


loss: 0.282210:  62%|██████▏   | 2999/4800 [1:47:57<55:07,  1.84s/it]  

[1, 3000] running_loss:0.282


loss: 0.269241:  75%|███████▍  | 3599/4800 [2:09:54<39:58,  2.00s/it]  

[1, 3600] running_loss:0.269


loss: 0.259566:  87%|████████▋ | 4199/4800 [2:31:13<18:37,  1.86s/it]  

[1, 4200] running_loss:0.260


loss: 0.254165: 100%|█████████▉| 4799/4800 [2:52:57<00:01,  1.99s/it]

[1, 4800] running_loss:0.254


loss: 0.254329: 100%|██████████| 4800/4800 [2:52:59<00:00,  2.16s/it]


Valid Accuracy: 92.0%

saving new weights...

Epoch 2/10
-------------------------------


loss: 0.240535:  12%|█▏        | 599/4800 [21:39<2:12:56,  1.90s/it]

[2,  600] running_loss:0.241


loss: 0.230127:  25%|██▍       | 1199/4800 [43:13<2:25:36,  2.43s/it]

[2, 1200] running_loss:0.230


loss: 0.219585:  37%|███▋      | 1799/4800 [1:05:19<2:01:25,  2.43s/it]

[2, 1800] running_loss:0.220


loss: 0.211653:  50%|████▉     | 2399/4800 [1:26:17<1:37:48,  2.44s/it] 

[2, 2400] running_loss:0.212


loss: 0.205723:  62%|██████▏   | 2999/4800 [1:48:40<42:23,  1.41s/it]  

[2, 3000] running_loss:0.206


loss: 0.200433:  75%|███████▍  | 3599/4800 [2:10:40<40:44,  2.04s/it]  

[2, 3600] running_loss:0.200


loss: 0.195492:  87%|████████▋ | 4199/4800 [2:32:47<29:23,  2.93s/it]  

[2, 4200] running_loss:0.195


loss: 0.191354: 100%|█████████▉| 4799/4800 [2:54:55<00:01,  1.53s/it]  

[2, 4800] running_loss:0.191


loss: 0.191334: 100%|██████████| 4800/4800 [2:54:57<00:00,  2.19s/it]


Valid Accuracy: 94.2%

saving new weights...

Epoch 3/10
-------------------------------


loss: 0.183005:  12%|█▏        | 599/4800 [21:42<4:11:43,  3.60s/it]

[3,  600] running_loss:0.183


loss: 0.177080:  25%|██▍       | 1199/4800 [44:26<1:41:37,  1.69s/it]

[3, 1200] running_loss:0.177


loss: 0.171108:  37%|███▋      | 1799/4800 [1:07:27<1:31:00,  1.82s/it]

[3, 1800] running_loss:0.171


loss: 0.165494:  50%|████▉     | 2399/4800 [1:29:44<1:42:09,  2.55s/it] 

[3, 2400] running_loss:0.165


loss: 0.161120:  62%|██████▏   | 2999/4800 [1:51:36<45:49,  1.53s/it]  

[3, 3000] running_loss:0.161


loss: 0.157728:  75%|███████▍  | 3599/4800 [2:12:33<27:51,  1.39s/it]  

[3, 3600] running_loss:0.158


loss: 0.153593:  87%|████████▋ | 4199/4800 [2:33:38<20:05,  2.01s/it]  

[3, 4200] running_loss:0.154


loss: 0.150313: 100%|█████████▉| 4799/4800 [2:56:10<00:02,  2.26s/it]  

[3, 4800] running_loss:0.150


loss: 0.150303: 100%|██████████| 4800/4800 [2:56:11<00:00,  2.20s/it]


Valid Accuracy: 92.4%

Epoch 4/10
-------------------------------


loss: 0.145716:  12%|█▏        | 599/4800 [23:04<2:55:35,  2.51s/it] 

[4,  600] running_loss:0.146


loss: 0.141172:  25%|██▍       | 1199/4800 [45:15<1:23:01,  1.38s/it]

[4, 1200] running_loss:0.141


loss: 0.137557:  37%|███▋      | 1799/4800 [1:07:21<2:29:13,  2.98s/it] 

[4, 1800] running_loss:0.138


loss: 0.133891:  50%|████▉     | 2399/4800 [1:28:55<57:02,  1.43s/it]  

[4, 2400] running_loss:0.134


loss: 0.131202:  62%|██████▏   | 2999/4800 [1:50:46<1:06:53,  2.23s/it]

[4, 3000] running_loss:0.131


loss: 0.127869:  75%|███████▍  | 3599/4800 [2:13:27<45:54,  2.29s/it]  

[4, 3600] running_loss:0.128


loss: 0.125343:  87%|████████▋ | 4199/4800 [2:35:52<19:13,  1.92s/it]  

[4, 4200] running_loss:0.125


loss: 0.122892: 100%|█████████▉| 4799/4800 [2:56:52<00:01,  1.99s/it]

[4, 4800] running_loss:0.123


loss: 0.122982: 100%|██████████| 4800/4800 [2:56:56<00:00,  2.21s/it]


Valid Accuracy: 95.0%

saving new weights...

Epoch 5/10
-------------------------------


loss: 0.120257:  12%|█▏        | 599/4800 [21:40<1:57:21,  1.68s/it]

[5,  600] running_loss:0.120


loss: 0.117290:  25%|██▍       | 1199/4800 [43:43<2:44:20,  2.74s/it]

[5, 1200] running_loss:0.117


loss: 0.114889:  37%|███▋      | 1799/4800 [1:05:13<1:29:44,  1.79s/it]

[5, 1800] running_loss:0.115


loss: 0.112178:  50%|████▉     | 2399/4800 [1:27:16<1:13:30,  1.84s/it]

[5, 2400] running_loss:0.112


loss: 0.109821:  62%|██████▏   | 2999/4800 [1:50:19<52:19,  1.74s/it]  

[5, 3000] running_loss:0.110


loss: 0.107862:  75%|███████▍  | 3599/4800 [2:12:11<49:32,  2.48s/it]  

[5, 3600] running_loss:0.108


loss: 0.105820:  87%|████████▋ | 4199/4800 [2:34:17<20:40,  2.06s/it]  

[5, 4200] running_loss:0.106


loss: 0.104173: 100%|█████████▉| 4799/4800 [2:55:18<00:02,  2.40s/it]

[5, 4800] running_loss:0.104


loss: 0.104169: 100%|██████████| 4800/4800 [2:55:19<00:00,  2.19s/it]


Valid Accuracy: 94.5%

Epoch 6/10
-------------------------------


loss: 0.101836:  12%|█▏        | 599/4800 [22:17<2:12:53,  1.90s/it]

[6,  600] running_loss:0.102


loss: 0.100227:  25%|██▍       | 1199/4800 [44:07<2:39:52,  2.66s/it]

[6, 1200] running_loss:0.100


loss: 0.098330:  37%|███▋      | 1799/4800 [1:06:17<2:04:49,  2.50s/it]

[6, 1800] running_loss:0.098


loss: 0.096485:  50%|████▉     | 2399/4800 [1:28:31<1:07:51,  1.70s/it]

[6, 2400] running_loss:0.096


loss: 0.094963:  62%|██████▏   | 2999/4800 [1:50:11<2:31:45,  5.06s/it]

[6, 3000] running_loss:0.095


loss: 0.093296:  75%|███████▍  | 3599/4800 [2:11:12<45:21,  2.27s/it]  

[6, 3600] running_loss:0.093


loss: 0.091627:  87%|████████▋ | 4199/4800 [2:32:35<19:36,  1.96s/it]  

[6, 4200] running_loss:0.092


loss: 0.090182: 100%|█████████▉| 4799/4800 [2:54:52<00:01,  1.36s/it]  

[6, 4800] running_loss:0.090


loss: 0.090179: 100%|██████████| 4800/4800 [2:54:53<00:00,  2.19s/it]


Valid Accuracy: 93.8%

Epoch 7/10
-------------------------------


loss: 0.088651:  12%|█▏        | 599/4800 [21:17<1:35:53,  1.37s/it]

[7,  600] running_loss:0.089


loss: 0.087061:  25%|██▍       | 1199/4800 [43:04<4:24:20,  4.40s/it]

[7, 1200] running_loss:0.087


loss: 0.085574:  37%|███▋      | 1799/4800 [1:06:00<1:31:00,  1.82s/it] 

[7, 1800] running_loss:0.086


loss: 0.084052:  50%|████▉     | 2399/4800 [1:29:02<2:17:24,  3.43s/it]

[7, 2400] running_loss:0.084


loss: 0.082800:  62%|██████▏   | 2999/4800 [1:49:32<46:20,  1.54s/it]  

[7, 3000] running_loss:0.083


loss: 0.081548:  75%|███████▍  | 3599/4800 [2:10:25<43:32,  2.18s/it]  

[7, 3600] running_loss:0.082


loss: 0.080444:  87%|████████▋ | 4199/4800 [2:32:30<17:47,  1.78s/it]  

[7, 4200] running_loss:0.080


loss: 0.079125: 100%|█████████▉| 4799/4800 [2:54:40<00:01,  1.47s/it]

[7, 4800] running_loss:0.079


loss: 0.079122: 100%|██████████| 4800/4800 [2:54:43<00:00,  2.18s/it]


Valid Accuracy: 93.9%

Epoch 8/10
-------------------------------


loss: 0.077859:  12%|█▏        | 599/4800 [22:25<2:13:16,  1.90s/it]

[8,  600] running_loss:0.078


loss: 0.076668:  25%|██▍       | 1199/4800 [44:52<1:49:17,  1.82s/it]

[8, 1200] running_loss:0.077


loss: 0.075553:  37%|███▋      | 1799/4800 [1:05:48<1:38:34,  1.97s/it]

[8, 1800] running_loss:0.076


loss: 0.074469:  50%|████▉     | 2399/4800 [1:28:01<1:17:35,  1.94s/it]

[8, 2400] running_loss:0.074


loss: 0.073331:  62%|██████▏   | 2999/4800 [1:50:25<1:48:04,  3.60s/it]

[8, 3000] running_loss:0.073


loss: 0.072338:  75%|███████▍  | 3599/4800 [2:12:14<33:00,  1.65s/it]  

[8, 3600] running_loss:0.072


loss: 0.071263:  87%|████████▋ | 4199/4800 [2:33:06<24:30,  2.45s/it]  

[8, 4200] running_loss:0.071


loss: 0.070386: 100%|█████████▉| 4799/4800 [2:54:22<00:01,  1.73s/it]

[8, 4800] running_loss:0.070


loss: 0.070384: 100%|██████████| 4800/4800 [2:54:25<00:00,  2.18s/it]


Valid Accuracy: 94.5%

Epoch 9/10
-------------------------------


loss: 0.069392:  12%|█▏        | 599/4800 [22:14<2:28:34,  2.12s/it]

[9,  600] running_loss:0.069


loss: 0.068432:  25%|██▍       | 1199/4800 [43:28<1:31:17,  1.52s/it]

[9, 1200] running_loss:0.068


loss: 0.067526:  37%|███▋      | 1799/4800 [1:04:57<2:13:47,  2.67s/it]

[9, 1800] running_loss:0.068


loss: 0.066609:  50%|████▉     | 2399/4800 [1:26:09<1:13:01,  1.82s/it]

[9, 2400] running_loss:0.067


loss: 0.065735:  62%|██████▏   | 2999/4800 [1:48:38<1:30:58,  3.03s/it]

[9, 3000] running_loss:0.066


loss: 0.064874:  75%|███████▍  | 3599/4800 [2:10:05<34:04,  1.70s/it]  

[9, 3600] running_loss:0.065


loss: 0.064154:  87%|████████▋ | 4199/4800 [2:30:15<28:57,  2.89s/it]  

[9, 4200] running_loss:0.064


loss: 0.063352: 100%|█████████▉| 4799/4800 [2:52:49<00:01,  1.90s/it]  

[9, 4800] running_loss:0.063


loss: 0.063351: 100%|██████████| 4800/4800 [2:52:51<00:00,  2.16s/it]


Valid Accuracy: 94.4%

Epoch 10/10
-------------------------------


loss: 0.062551:  12%|█▏        | 599/4800 [22:53<2:25:25,  2.08s/it] 

[10,  600] running_loss:0.063


loss: 0.061786:  25%|██▍       | 1199/4800 [45:04<2:08:54,  2.15s/it]

[10, 1200] running_loss:0.062


loss: 0.061057:  37%|███▋      | 1799/4800 [1:08:22<2:12:12,  2.64s/it]

[10, 1800] running_loss:0.061


loss: 0.060337:  50%|████▉     | 2399/4800 [1:29:14<1:02:00,  1.55s/it]

[10, 2400] running_loss:0.060


loss: 0.059635:  62%|██████▏   | 2999/4800 [1:50:32<50:55,  1.70s/it]  

[10, 3000] running_loss:0.060


loss: 0.058938:  75%|███████▍  | 3599/4800 [2:11:40<52:54,  2.64s/it]  

[10, 3600] running_loss:0.059


loss: 0.058242:  87%|████████▋ | 4199/4800 [2:32:52<30:50,  3.08s/it]  

[10, 4200] running_loss:0.058


loss: 0.057600: 100%|█████████▉| 4799/4800 [2:52:47<00:02,  2.28s/it]

[10, 4800] running_loss:0.058


loss: 0.057598: 100%|██████████| 4800/4800 [2:52:50<00:00,  2.16s/it]


Valid Accuracy: 94.3%

Done!
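
Since one goal of this task is to record timings, the per-epoch durations and validation accuracies from the log above can be summarized in a few lines. The sketch below copies those values by hand from the progress-bar output (time spent in the validation loop is not logged and so is not included) and replays the `valid_acc > best_acc` rule to show which epochs wrote a checkpoint.

```python
# Values transcribed from the training log (batch_size=2 run).
epoch_times = ["2:52:59", "2:54:57", "2:56:11", "2:56:56", "2:55:19",
               "2:54:53", "2:54:43", "2:54:25", "2:52:51", "2:52:50"]
valid_acc = [92.0, 94.2, 92.4, 95.0, 94.5, 93.8, 93.9, 94.5, 94.4, 94.3]


def to_seconds(hms):
    """Convert an 'H:MM:SS' string to seconds."""
    h, m, s = (int(part) for part in hms.split(":"))
    return h * 3600 + m * 60 + s


total_seconds = sum(to_seconds(t) for t in epoch_times)
print(f"training time: {total_seconds / 3600:.1f} hours")

# Replay the checkpoint rule: save only when validation accuracy strictly improves.
best_acc, saved_epochs = 0.0, []
for epoch, acc in enumerate(valid_acc, start=1):
    if acc > best_acc:
        best_acc = acc
        saved_epochs.append(epoch)
print("checkpoints saved at epochs:", saved_epochs, "best:", best_acc)
```

The ten training loops add up to about 29.1 hours on CPU. Validation accuracy peaks at epoch 4 (95.0%) while the running training loss keeps falling, so epochs 5 to 10 did not improve the saved model.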


# 2022.10.24 training
batch=4, 2400 batches per epoch; the loss could be printed every 300 batches

In [9]:
from transformers import AdamW, get_scheduler

learning_rate = 1e-5  # learning rate
epoch_num = 10  # number of epochs

loss_fn = nn.CrossEntropyLoss()  # loss function: cross-entropy
optimizer = AdamW(model.parameters(), lr=learning_rate)  # AdamW, a commonly used optimizer
lr_scheduler = get_scheduler(
    "linear",  # decrease the learning rate linearly
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
best_acc = 0.
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, loss_fn, optimizer, lr_scheduler, t+1, total_loss)
    valid_acc = test_loop(valid_dataloader, model, mode='Valid')
    if valid_acc > best_acc:
        best_acc = valid_acc
        print('saving new weights...\n')
        # Save the model weights
        torch.save(model.state_dict(), f'epoch_{t+1}_valid_acc_{(100*valid_acc):0.1f}_model_weights.bin')
print("Done!")

# A checkpoint is written whenever validation accuracy improves, so the last file saved is the best model

Epoch 1/10
-------------------------------


loss: 0.305601:  25%|██▍       | 599/2400 [57:17<2:44:34,  5.48s/it] 

[1,  600] running_loss:0.306


loss: 0.280461:  50%|████▉     | 1199/2400 [1:56:49<1:43:59,  5.20s/it]

[1, 1200] running_loss:0.280


loss: 0.279828:  51%|█████▏    | 1230/2400 [2:00:50<5:02:15, 15.50s/it]

RuntimeError: [enforce fail at alloc_cpu.cpp:66] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 737587200 bytes. Error code 12 (Cannot allocate memory)
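
The failed allocation is consistent with the memory needed for attention scores. The model summary shows that the checkpoint was loaded as a `BertModel`, whose self-attention builds a score tensor of shape (batch, heads, seq_len, seq_len), so memory grows with the square of the padded sequence length. The sketch below computes the size of one such float32 tensor. The head count of 12 is an assumption (the BERT-base default, not shown in the printout), and the sequence length of 1960 is only one combination that happens to reproduce the 737,587,200 bytes in the error message, not a value confirmed by the log.

```python
def attention_score_bytes(batch_size, seq_len, num_heads=12, bytes_per_value=4):
    """Bytes for one (batch, heads, seq_len, seq_len) float32 attention score tensor."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value


for batch_size, seq_len in [(2, 4096), (4, 4096), (4, 1960), (4, 512)]:
    size = attention_score_bytes(batch_size, seq_len)
    print(f"batch={batch_size} seq_len={seq_len}: {size:,} bytes ({size / 2**30:.2f} GiB)")
```

Several such tensors can be alive at once during training (one per layer, kept for the backward pass), so a single long review in a batch of 4 may exhaust CPU memory even when most batches fit.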