## 数据预处理

In [1]:
import torch
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.backends.mps.is_available())

2.3.1+cpu
False
False


In [2]:
import torch
import time
from datetime import timedelta
import os
import pickle as pkl
from transformers import BertTokenizer
from tqdm import tqdm


### 1. 分词和构建词表

In [3]:
UNK, PAD, CLS = "[UNK]", "[PAD]", "[CLS]"  # 特殊符号
MAX_VOCAB_SIZE = 10000  # 词表长度限制

# 定义构建词表的函数
def build_vocab(file_path, tokenizer, max_size, min_freq):
    vocab_dic = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="构建词表中"):
            line = line.strip()
            if not line:
                continue
            content = line.split('\t')[0]  # 假设文本在每行的第一列
            for word in tokenizer.tokenize(content):  # 正确调用 tokenize 方法
                vocab_dic[word] = vocab_dic.get(word, 0) + 1

    # 根据频率和词表大小筛选词汇
    vocab_list = sorted(
        [(word, count) for word, count in vocab_dic.items() if count >= min_freq],
        key=lambda x: x[1],
        reverse=True
    )[:max_size]

    # 创建词表字典并加入特殊符号
    vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
    vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1, CLS: len(vocab_dic) + 2})
    return vocab_dic

# 使用 transformers 的 BertTokenizer 加载中文分词器
# tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
tokenizer = BertTokenizer.from_pretrained('./bert_pretrain')

# 构建词表
vocab_dic = build_vocab('./data/train.txt', tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
print("生成的词表大小:", len(vocab_dic))
print(vocab_dic.items())

构建词表中: 180000it [00:20, 8691.25it/s]

生成的词表大小: 7768
dict_items([('：', 0), ('大', 1), ('国', 2), ('图', 3), ('(', 4), (')', 5), ('人', 6), ('年', 7), ('中', 8), ('新', 9), ('生', 10), ('金', 11), ('高', 12), ('[UNK]', 7766), ('《', 14), ('》', 15), ('上', 16), ('不', 17), ('考', 18), ('一', 19), ('日', 20), ('元', 21), ('开', 22), ('美', 23), ('价', 24), ('发', 25), ('学', 26), ('公', 27), ('成', 28), ('月', 29), ('将', 30), ('万', 31), ('基', 32), ('市', 33), ('出', 34), ('子', 35), ('行', 36), ('机', 37), ('业', 38), ('被', 39), ('家', 40), ('股', 41), ('的', 42), ('在', 43), ('网', 44), ('女', 45), ('期', 46), ('平', 47), ('房', 48), ('名', 49), ('三', 50), ('-', 51), ('会', 52), ('地', 53), ('场', 54), ('全', 55), ('小', 56), ('现', 57), ('有', 58), ('分', 59), ('后', 60), ('称', 61), ('组', 62), ('为', 63), ('下', 64), ('3', 65), ('盘', 66), ('最', 67), ('手', 68), ('2', 69), ('天', 70), ('本', 71), ('利', 72), ('首', 73), ('战', 74), ('长', 75), ('游', 76), ('海', 77), ('主', 78), ('起', 79), ('动', 80), ('北', 81), ('资', 82), ('售', 83), ('能', 84), ('重', 85), ('时', 86), ('男', 87), ('1', 88),




### 2. 构建数据集

In [4]:
def build_dataset(config):
    """
    根据提供的配置文件加载训练集、验证集和测试集，并对数据进行预处理。
    Args:
        config: 包含配置信息的对象，包含路径、分词器、pad_size 等。
    Returns:
        train, dev, test: 预处理后的训练集、验证集和测试集。
    """
    def load_dataset(path, pad_size=32):
        """
        加载并预处理单个数据集。
        Args:
            path: 数据文件路径。
            pad_size: 序列的最大长度。如果小于 pad_size，则进行填充；如果大于，则截断。
        Returns:
            contents: 包含 (token_ids, label, seq_len, mask) 的数据列表。
        """
        contents = []
        with open(path, "r", encoding='utf-8') as f:
            for line in tqdm(f):  # tqdm 用于显示进度条
                line = line.strip()
                if not line:
                    continue  # 跳过空行
                # 数据格式假定为 '文本\t标签'
                content, label = line.split('\t')
                token = config.tokenizer.tokenize(content)  # 分词操作
                token = [CLS] + token  # 在序列开头添加特殊标记 [CLS]
                seq_len = len(token)   # 序列长度
                token_ids = config.tokenizer.convert_tokens_to_ids(token)   # 转换为 ID
                mask = []   # 构建 mask 和 padding
                if pad_size:
                    if len(token) < pad_size:
                        # 如果序列长度不足 pad_size，填充 0
                        mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
                        token_ids += [0] * (pad_size - len(token))
                    else:
                        # 如果序列长度超过 pad_size，进行截断
                        mask = [1] * pad_size
                        token_ids = token_ids[:pad_size]
                        seq_len = pad_size

                # 将处理后的数据添加到结果列表
                contents.append((token_ids, int(label), seq_len, mask))

        return contents

    # 加载训练集、验证集和测试集
    train = load_dataset(config.train_path, config.pad_size)
    dev = load_dataset(config.dev_path, config.pad_size)
    test = load_dataset(config.test_path, config.pad_size)
    return train, dev, test
                

### 3. 数据封装

In [5]:
class DatasetIterater(object):
    def __init__(self, batches, batch_size, device, model_name):
        self.batch_size = batch_size
        self.batches = batches
        self.model_name = model_name
        self.n_batches = len(batches) // batch_size
        self.residue = False    # 记录batch数量是否为整数
        if len(batches) % self.n_batches != 0:
            self.residue = True # batches不能被batch_size整除
        self.index = 0
        self.device = device
    
    def _to_tensor(self, datas):
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
        # pad 前的长度（超过pad_size的设置为pad_size）
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        if self.model_name == 'bert' or self.model_name == 'multi_task_bert':
            mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
            return (x, seq_len, mask), y
    
    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.batches[self.index * self.batch_size: len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches
        elif self.index >= self.n_batches:
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size : (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches
    
    def __iter__(self):
        return self
    
    def __len__(self):
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches

def build_iterator(dataset, config):
    iter = DatasetIterater(dataset, config.batch_size, config.device, config.model_name)
    return iter

In [6]:
def get_time_dif(start_time):
    # 获取已使用时间
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))

## BERT分类模型搭建

### 1. 实现Config类代码

In [7]:
import torch
import torch.nn as nn
import os
from transformers import BertModel, BertTokenizer, BertConfig

class Config(object):
    def __init__(self, dataset):
        self.model_name = 'bert'
        self.data_path = './data/'
        self.train_path = self.data_path + "train.txt"  # 训练集
        self.dev_path = self.data_path + "dev.txt"  # 验证集
        self.test_path = self.data_path + "test.txt"    # 测试集
        self.class_list = [
            x.strip() for x in open(self.data_path + "class.txt").readlines()
        ]   # 类别名单
        self.save_path = './cache'
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)
        self.save_path += "/" + self.model_name + ".pt" # 模型训练结果
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.require_improvement = 1000 # 若超过1000batch效果还没有提升，则提前结束训练
        self.num_classes = len(self.class_list) # 类别数
        self.num_epochs = 3 # epoch数
        self.batch_size = 128   # mini-batch大小
        self.pad_size = 32  # 每句话处理成的长度（短补长截）
        self.learning_rate = 5e-5   # 学习率
        self.bert_path = './bert_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.bert_config = BertConfig.from_pretrained(self.bert_path + '/bert_config.json')
        self.hidden_size = 768


In [8]:
config = Config('toutiao')
# 构建词表
vocab_dic = build_vocab('./data/test.txt', tokenizer=config.tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
print("生成的词表大小:", len(vocab_dic))
print(config.device, config.train_path)

构建词表中: 10000it [00:01, 8904.64it/s]

生成的词表大小: 4676
cpu ./data/train.txt





### 2. 实现Model类

In [9]:
from transformers import AutoModel

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path, config=config.bert_config)
        # self.bert = BertModel.from_pretrained("bert-base-chinese")
        # self.bert = AutoModel.from_pretrained("bert-base-chinese")

        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        context = x[0]
        mask = x[2]
        # _, pooled = self.bert(context, attention_mask=mask)
        # out = self.fc(pooled)
        
        outputs = self.bert(context, attention_mask=mask)
        pooled = outputs.pooler_output
        # print(pooled.shape)
        out = self.fc(pooled)
        return out


## 编写训练、测试、评估函数

### 1. 训练函数

In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
# from utils import get_time_dif
from torch.optim import AdamW
from tqdm import tqdm
import math
import logging

In [11]:
def train(config, model, train_iter, dev_iter):
    start_time = time.time()
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01
        },
        {
            "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0
        }
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)
    loss_fn = nn.CrossEntropyLoss()  # 损失函数定义一次

    total_batch = 0  # 记录进行到多少batch
    dev_best_loss = float('inf')
    last_improve = 0  # 记录上次验证集loss下降的batch数
    flag = False  # 记录是否很久没有效果提升，用于判断是否早停

    model.train()
    for epoch in range(config.num_epochs):
        print(f"Epoch [{epoch + 1}/{config.num_epochs}]")
        for i, (trains, labels) in enumerate(tqdm(train_iter)):
            # 解构输入，确保符合 forward() 方法
            # context, seq_len, mask = trains
            outputs = model(trains)
            
            model.zero_grad()  # 梯度清零
            loss = loss_fn(outputs, labels)  # 计算损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数

            if total_batch % 200 == 0 and total_batch != 0:
                # 每 200 轮输出在训练集和验证集上的效果
                true = labels.data.cpu()
                predict = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predict)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss  # 更新最佳验证集loss
                    torch.save(model.state_dict(), config.save_path)  # 保存模型
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ""
                time_dif = get_time_dif(start_time)
                msg = f"Iter: {total_batch}, Train Loss: {loss.item():.2f}, Train Acc: {train_acc:.2%}, Val Loss: {dev_loss:.2f}, Val Acc: {dev_acc:.2%}, Time: {time_dif} {improve}"
                print(msg)
                model.train()

            total_batch += 1

            if total_batch - last_improve > config.require_improvement:
                # 验证集loss超过指定batch没有下降，提前结束训练
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break

    

### 2. 测试函数

In [12]:
def test(config, model, test_iter):
    # model.load_state_dict(torch.load(config.save_path))
    # 采用量化模型进行推理时需要关闭
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    
    msg = "Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}"
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

### 3. 验证函数

In [13]:
def evaluate(config, model, data_iter, test=False):
    # 采用量化模型进行推理时需要关闭
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predict = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predict)
    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)

## 运行主函数

In [14]:
import time
import torch
import numpy as np
# from train_eval import train, test
from importlib import import_module
import argparse
# from utils import build_dataset, build_iterator, get_time_dif

### 1. 加载数据集

In [15]:
dataset = "toutiao" # 数据集
config = Config(dataset)
print(config.train_path)

print("Loading data for Bert Model...")
train_data, dev_data, test_data = build_dataset(config)
train_iter = build_iterator(train_data, config)
dev_iter = build_iterator(dev_data, config)
test_iter = build_iterator(test_data, config)

./data/train.txt
Loading data for Bert Model...


180000it [00:28, 6232.96it/s]
10000it [00:01, 6455.64it/s]
10000it [00:01, 6512.19it/s]


In [16]:
print(len(train_iter), len(test_iter), len(dev_iter))
for i, (trains, labels) in enumerate(tqdm(train_iter)):
    # 解构输入，确保符合 forward() 方法
    context, seq_len, mask = trains
    print(context.shape, seq_len.shape, mask.shape)
    break
display(trains[0], trains[1], trains[2])

1407 79 79


  0%|          | 0/1407 [00:00<?, ?it/s]

torch.Size([128, 32]) torch.Size([128]) torch.Size([128, 32])





tensor([[ 101,  704, 1290,  ...,    0,    0,    0],
        [ 101,  697, 1921,  ...,    0,    0,    0],
        [ 101,  691,  126,  ...,    0,    0,    0],
        ...,
        [ 101,  783, 7183,  ...,    0,    0,    0],
        [ 101, 2458, 4669,  ...,    0,    0,    0],
        [ 101, 3136, 5509,  ...,    0,    0,    0]])

tensor([19, 23, 21, 25, 22, 21, 16, 22, 16, 12, 21, 23, 22, 16,  8, 17, 20, 24,
         8, 10, 18, 16, 24, 21, 18, 15, 11, 21, 19, 19, 22, 22, 17, 23, 24, 17,
        13, 18, 23, 19, 23, 21, 23, 21, 20, 14, 18, 16, 18, 24, 16, 23, 21, 17,
        16, 13, 23, 20, 21, 21, 13, 23, 18, 15, 25, 17, 21, 23, 23, 14, 20, 20,
        18, 17, 23, 15, 23, 21, 20, 15, 22, 21, 22, 20, 20, 15, 13, 21, 22, 15,
        21, 21, 23, 15, 23, 19, 17, 18, 14, 21, 14, 16, 21, 12, 17, 23, 15, 22,
        16, 16, 16, 18, 22, 16, 25, 17, 19, 18, 15, 18, 13, 22, 21, 14, 22, 17,
        16, 22])

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])

### 2. 实例化模型

In [17]:
model = Model(config)
print(model)
res = model(trains)
res.shape

Model(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=

torch.Size([128, 10])

In [18]:
res

tensor([[-0.3605,  0.0844, -0.2850,  ...,  0.0751, -0.4689, -0.4703],
        [-0.5437, -0.1852, -0.3007,  ..., -0.2121, -0.2325, -0.3381],
        [-0.5249,  0.2110, -0.0453,  ..., -0.2327, -0.1728, -0.5269],
        ...,
        [-0.8038,  0.3867,  0.1551,  ..., -0.1965, -0.1082, -0.2632],
        [-0.7965,  0.0941,  0.2110,  ..., -0.1926, -0.3037, -0.4126],
        [-0.5278, -0.0257, -0.4080,  ..., -0.0726, -0.5996, -0.4165]],
       grad_fn=<AddmmBackward0>)

In [19]:
# train(config, model, train_iter, dev_iter)

## 知识蒸馏

### 1. 分词和构建词表

In [20]:
import torch
import time
from datetime import timedelta
import os
import pickle as pkl
from transformers import BertTokenizer
from tqdm import tqdm


UNK, PAD, CLS = "[UNK]", "[PAD]", "[CLS]"  # 特殊符号
MAX_VOCAB_SIZE = 10000  # 词表长度限制

def build_vocab(file_path, tokenizer, max_size, min_freq):
    vocab_dic = {}
    with open(file_path, "r", encoding='utf-8') as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                continue
            content = line.split("\t")[0]   # 获取文本内容
            for word in tokenizer(content):
                vocab_dic[word] = vocab_dic.get(word, 0) + 1
        vocab_list = sorted(
            [_ for _ in vocab_dic.items() if _[1]>=min_freq], key=lambda x: x[1], reverse=True
        )[:max_size]
        vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}   # word_to_idx 字典
        vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
    return vocab_dic


tokenizer = lambda x: [y for y in x]    # char-level
# 构建词表
vocab_dic = build_vocab('./data/train.txt', tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
print("生成的词表大小:", len(vocab_dic))
print(vocab_dic.items())


180000it [00:00, 270144.88it/s]

生成的词表大小: 4762
dict_items([(' ', 0), ('0', 1), ('1', 2), ('2', 3), ('：', 4), ('大', 5), ('国', 6), ('图', 7), ('(', 8), (')', 9), ('3', 10), ('人', 11), ('年', 12), ('5', 13), ('中', 14), ('新', 15), ('9', 16), ('生', 17), ('金', 18), ('高', 19), ('《', 20), ('》', 21), ('4', 22), ('上', 23), ('8', 24), ('不', 25), ('考', 26), ('一', 27), ('6', 28), ('日', 29), ('元', 30), ('开', 31), ('美', 32), ('价', 33), ('发', 34), ('学', 35), ('公', 36), ('成', 37), ('月', 38), ('将', 39), ('万', 40), ('7', 41), ('基', 42), ('市', 43), ('出', 44), ('子', 45), ('行', 46), ('机', 47), ('业', 48), ('被', 49), ('家', 50), ('股', 51), ('的', 52), ('在', 53), ('网', 54), ('女', 55), ('期', 56), ('平', 57), ('房', 58), ('名', 59), ('三', 60), ('-', 61), ('会', 62), ('地', 63), ('场', 64), ('全', 65), ('小', 66), ('现', 67), ('有', 68), ('分', 69), ('后', 70), ('称', 71), ('组', 72), ('为', 73), ('下', 74), ('盘', 75), ('最', 76), ('“', 77), ('”', 78), ('手', 79), ('天', 80), ('本', 81), ('利', 82), ('首', 83), ('战', 84), ('长', 85), ('游', 86), ('海', 87), ('主', 88), ('起',




### 2. 构建数据集

In [21]:
def bulid_dataset_CNN(config):
    tokenizer = lambda x: [y for y in x]    # char-level
    if os.path.exists(config.vocab_path):
        vocab = pkl.load(open(config.vocab_path, "rb"))
    else:
        vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
        pkl.dump(vocab, open(config.vocab_path, "wb"))
    
    def load_dataset(path, pad_size=32):
        contents = []
        with open(path, "r", encoding="utf-8") as f:
            for line in tqdm(f):
                line = line.strip()
                if not line:
                    continue
                content, label = line.split("\t")
                words_line = []
                token = tokenizer(content)
                seq_len = len(token)
                if pad_size:
                    if seq_len < pad_size:
                        token.extend([PAD] * (pad_size - seq_len))
                    else:
                        token = token[:pad_size]
                        seq_len = pad_size
                # word_to_idx
                for word in token:
                    words_line.append(vocab.get(word, vocab.get(UNK)))
                contents.append((words_line, int(label), seq_len))
        return contents
    train = load_dataset(config.train_path, config.pad_size)
    dev = load_dataset(config.dev_path, config.pad_size)
    test = load_dataset(config.test_path, config.pad_size)
    return vocab, train, dev, test
    

### 3. 数据封装

In [22]:
class DatasetIterater(object):
    def __init__(self, batches, batch_size, device, model_name):
        self.batch_size = batch_size
        self.batches = batches
        self.model_name = model_name
        self.n_batches = len(batches) // batch_size
        self.residue = False  # 记录batch数量是否为整数
        if len(batches) % self.n_batches != 0:
            self.residue = True
        self.index = 0
        self.device = device

    def _to_tensor(self, datas):
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)

        # pad前的长度(超过pad_size的设为pad_size)
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        if self.model_name == "bert" or self.model_name == "multi_task_bert":
            mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
            return (x, seq_len, mask), y
        if self.model_name == "textCNN":
            return (x, seq_len), y

    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.batches[self.index * self.batch_size : len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

        elif self.index >= self.n_batches:
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size : (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

    def __iter__(self):
        return self

    def __len__(self):
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches


def build_iterator(dataset, config):
    iter = DatasetIterater(dataset, config.batch_size, config.device, config.model_name)
    return iter

In [23]:
def get_time_dif(start_time):
    # 获取已使用时间
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))