# Step 1: Prepare the configuration: dataset, model, and training parameters

# In [8]:
import torch
from transformers import BertTokenizer

class Config(object):
    """Paths, hyper-parameters and tokenizer for the BERT classifier.

    Args:
        dataset: Accepted to keep the constructor signature stable; this
            constructor does not read it (all paths below are hard-coded).
    """

    def __init__(self, dataset):
        self.model_name = 'bert'

        # Data splits.
        self.train_path = '../data/caruser/train.txt'
        self.test_path = '../data/caruser/test.txt'
        self.val_path = '../data/caruser/val.txt'
        # build_dataset() reads `config.dev_path`; expose the validation
        # split under that name as well so it does not raise AttributeError.
        self.dev_path = self.val_path

        # NOTE(review): the leading '/' makes this an absolute path at the
        # filesystem root ('/saved_dict/bert.ckpt') - confirm that is intended.
        self.save_path = '/saved_dict/' + self.model_name + '.ckpt'

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Training hyper-parameters.
        # presumably an early-stopping patience; verify against the training loop
        self.require_improvement = 1000
        self.num_epochs = 3
        self.batch_size = 128
        self.learning_rate = 5e-5
        # Fixed sequence length used when padding/truncating token ids.
        # build_dataset() reads `config.pad_size`; 32 matches the default of
        # its inner load_dataset() helper.
        self.pad_size = 32

        # Pre-trained BERT weights/vocabulary directory and its tokenizer.
        self.bert_path = './bert-base-chinese'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        # Width of the BERT output fed to the classification heads.
        self.hidden_size = 768

# Step 2: Define the model

# In [9]:
from torch import nn
from transformers import BertModel
from torchviz import make_dot

class Model(nn.Module):
    """BERT encoder with two linear heads on top of the pooled output.

    One head produces 10 topic logits, the other 3 emotion logits.

    Args:
        config: Object providing ``bert_path`` and ``hidden_size``. An
            optional boolean ``render_graph`` attribute (default ``False``)
            enables a one-time export of the autograd graph to
            ``bert_model.pdf``.
    """

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        # Keep every encoder weight trainable (full fine-tuning).
        for param in self.bert.parameters():
            param.requires_grad = True
        self.topic_fc = nn.Linear(config.hidden_size, 10)
        self.emo_fc = nn.Linear(config.hidden_size, 3)

        # Graph export is opt-in: doing it unconditionally inside forward()
        # would run Graphviz and rewrite the PDF on every batch.
        self.render_graph = getattr(config, "render_graph", False)
        self._graph_rendered = False

    def forward(self, x):
        """Run the encoder and both heads.

        Args:
            x: Indexable batch; ``x[0]`` is passed to BERT as the input ids
                and ``x[2]`` as the attention mask.

        Returns:
            Tuple ``(topic_out, emo_out)`` of the two heads' outputs.
        """
        context = x[0]
        mask = x[2]

        outputs = self.bert(context, attention_mask=mask)
        pooled = outputs.pooler_output
        topic_out = self.topic_fc(pooled)
        emo_out = self.emo_fc(pooled)

        if self.render_graph and not self._graph_rendered:
            # Hand make_dot the output tensors (not the encoder's output
            # object) so the graph covers both heads, and render only once.
            graph = make_dot(
                (topic_out, emo_out),
                params=dict(self.named_parameters()),
            )
            graph.render("bert_model", format="pdf")
            self._graph_rendered = True

        return topic_out, emo_out

# Step 3: Process the data

# In [10]:
from tqdm import tqdm


def build_dataset(config):
    """Load and encode the train, dev and test splits.

    Each non-empty line of a split file is ``text<TAB>topic#score<TAB>...``.
    The text is tokenized with ``config.tokenizer``, prefixed with ``[CLS]``
    and padded/truncated to a fixed length.

    Args:
        config: Object providing ``tokenizer``, ``train_path``, ``test_path``
            and either ``dev_path`` or ``val_path``. ``pad_size`` is optional
            and defaults to 32.

    Returns:
        Tuple ``(train, dev, test)``. Each element is a list of
        ``(token_ids, seq_len, mask, topic_labels, emotion_score)`` tuples,
        where ``topic_labels`` is a 10-element multi-hot list and
        ``emotion_score`` a 3-element one-hot list
        (negative / zero / positive summed score).

    Raises:
        ValueError: If a label field is not of the form ``topic#score`` or
            the score is not an integer.
    """
    # Built once instead of once per line; the dict gives O(1) lookups.
    topic_list = ["操控", "内饰", "安全性", "空间", "舒适性", "外观", "动力", "价格", "配置", "油耗"]
    topic_index = {topic: idx for idx, topic in enumerate(topic_list)}

    def load_dataset(path, pad_size=32):
        """Encode every non-empty line of ``path`` into a feature tuple."""
        contents = []
        with open(path, "r", encoding="UTF-8") as f:
            for line in tqdm(f):
                lin = line.strip()
                if not lin:
                    continue
                content, *labels = lin.split("\t")
                token = ["[CLS]"] + config.tokenizer.tokenize(content)
                seq_len = len(token)
                token_ids = config.tokenizer.convert_tokens_to_ids(token)

                topic_labels = [0] * len(topic_list)
                emotion_score_val = 0
                for label in labels:
                    topic, score = label.split("#")
                    idx = topic_index.get(topic)
                    if idx is not None:
                        topic_labels[idx] = 1
                        # Accumulate the sentiment score over known topics.
                        emotion_score_val += int(score)

                if emotion_score_val < 0:
                    emotion_score = [1, 0, 0]
                elif emotion_score_val == 0:
                    emotion_score = [0, 1, 0]
                else:
                    emotion_score = [0, 0, 1]

                real_len = len(token_ids)
                if real_len < pad_size:
                    # Compute the mask from the length BEFORE padding, so
                    # that padded positions get 0 and are ignored by the
                    # attention mask.
                    pad_len = pad_size - real_len
                    mask = [1] * real_len + [0] * pad_len
                    token_ids += [0] * pad_len
                else:
                    # Truncate to pad_size; every kept position is real.
                    token_ids = token_ids[:pad_size]
                    mask = [1] * pad_size
                    seq_len = pad_size

                contents.append((token_ids, seq_len, mask, topic_labels, emotion_score))
        return contents

    # Tolerate configs that name the validation split `val_path` and/or
    # omit `pad_size`.
    pad_size = getattr(config, "pad_size", 32)
    dev_path = getattr(config, "dev_path", None)
    if dev_path is None:
        dev_path = config.val_path

    train = load_dataset(config.train_path, pad_size)
    dev = load_dataset(dev_path, pad_size)
    test = load_dataset(config.test_path, pad_size)
    return train, dev, test