In [163]:
import jieba
import numpy as np
import random
import torch
import torch.nn as nn
from torch.nn.functional import pad
import torch.nn.functional as F
import codecs
import json
import re
from config import Config

In [153]:
maxlen = 20
seed = 2022
vocab_dim = 512
window_sizes = [3, 3, 5, 7]

In [105]:
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

In [106]:
def load_dictionaries():
    """返回每个词语的索引，词向量，以及每个句子所对应的词语索引"""
    print("加载词典")
    w2indx_json = codecs.open(Config.w2indx_path, 'r', encoding=Config.encoding)
    w2vec_json = codecs.open(Config.w2vec_path, 'r', encoding=Config.encoding)

    w2indx = json.load(w2indx_json)
    W2VEC = json.load(w2vec_json)
    w2vec = dict()
    for key, value in W2VEC.items():
        w2vec[key] = np.asarray(value)

    w2indx_json.close()
    w2vec_json.close()

    return w2indx, w2vec

In [140]:
def parse_dataset(combined):
    """将combined中的数据转换为索引表示"""
    data = []
    for sentence in combined:
        new_txt = []
        for word in sentence:
            try:
                new_txt.append(w2indx[word])
            except:
                new_txt.append(0)
        if len(new_txt) > maxlen:
            new_txt=torch.Tensor(new_txt[:maxlen]).int()
        elif len(new_txt) < maxlen:
            new_txt = torch.Tensor(new_txt).int()
            new_txt = pad(new_txt, (0, maxlen - new_txt.shape[0]))

        data.append(new_txt)
    return data

In [181]:
def get_embeddings(combined, embedding_weights):
    """根据combined获取embedding"""
    inputs = [embedding_weights[sen] for sen in combined]
    inputs = torch.Tensor(inputs)
    print(inputs)
    return inputs


In [182]:
def regular(str_list: list) -> list:
    """
    句子规范化，主要是对原始语料的句子进行一些标点符号的统一处理
    """
    sen = []
    for index, line in enumerate(str_list):
        line = re.sub(r'…{1,100}', '…', line)
        line = re.sub(r'\.{3,100}', '…', line)
        line = re.sub(r'···{2,100}', '…', line)
        line = re.sub(r'\.{1,100}', '。', line)
        line = re.sub(r'。{1,100}', '。', line)
        line = re.sub(r'？{1,100}', '？', line)
        line = re.sub(r'!{1,100}', '！', line)
        line = re.sub(r'！{1,100}', '！', line)
        line = re.sub(r'~{1,100}', '～', line)
        line = re.sub(r'～{1,100}', '～', line)
        line = re.sub(r'\d*\.\d+|\d+', '1', line)  # 将所有数字都替换成1
        sen.append(line)
    return sen

In [183]:
def get_inputs(str_list: list, embedding_weights):
    # 句子规范化
    str_list = regular(str_list)
    # 分词处理
    combined = []
    for s in str_list:
        words = jieba.lcut(s)
        words = " ".join(words)
        combined.append(words)

    # 将词转换成对应的id
    combined = parse_dataset(str_list)  # 将combined中的数据转换为索引表示
    combined = nn.utils.rnn.pad_sequence(combined, batch_first=True, padding_value=0)
    print(combined.shape)

    # 将词id转换成对应的embedding
    inputs = get_embeddings(combined, embedding_weights)
    print(inputs.shape)

    return inputs

In [228]:
def Print(inputs, outputs):
    emo_dict = {0: 'null', 1: 'like', 2: 'sad', 3: 'disgust', 4: 'angry', 5: 'happy'}
    outputs = outputs.argmax(axis=1).cpu()

    for i in range(len(inputs)):
        print(f'{inputs[i]} --> {emo_dict[outputs[i].item()]}')

In [219]:
class TextCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([
                nn.Sequential(nn.Conv1d(in_channels=vocab_dim,
                                        out_channels=256,
                                        kernel_size=h),
                              nn.Dropout(0.2),
                              nn.BatchNorm1d(num_features=256),
                              nn.ReLU(),
                              nn.MaxPool1d(kernel_size=maxlen-h+1))
                     for h in window_sizes
                    ])
        self.fc = nn.Linear(in_features=256 * len(window_sizes),
                            out_features=6)

    def forward(self, x):
        # inputs = [batch, maxlen, vocab_dim]

        # SpatialDropout1D
        x = x.permute(0, 2, 1)   #  [batch, vocab_dim, maxlen]
        x = F.dropout2d(x, 0.3, training=self.training)
        x = x.permute(0, 2, 1)   # back to  [batch, maxlen, vocab_dim]

        # batch_size x text_len x embedding_size  -> batch_size x embedding_size x text_len
        x = x.permute(0, 2, 1)

        out = torch.cat([conv(x).squeeze(-1) for conv in self.convs], dim=1)
        out = self.fc(out)

        return out

In [220]:
def predict(str_list, embedding_weights):
    print('loading model......')
    model = TextCNN()
    model.load_state_dict(torch.load("model/model.pth")['state_dict']) # 只保存了训练参数
    model.cuda()

    # 处理输入
    inputs = get_inputs(str_list, embedding_weights)
    inputs = inputs.cuda()

    # 进行预测
    model.eval()
    outputs = model(inputs)

    # 输出结果
    Print(str_list, outputs)
    return outputs


In [204]:
w2indx, w2vec = load_dictionaries()
n_symbols = len(w2indx) + 1
embedding_weights = np.zeros((n_symbols, vocab_dim))
for word, index in w2indx.items():  # 从索引为1的词语开始，对每个词语对应其词向量
    embedding_weights[index, :] = w2vec[word]

加载词典


In [229]:
str_list = ['一只小泰迪', '你想干什么？!', ' 猫猫真可爱！', '请问您今天要来点兔子吗？', '喜欢一个人是隐藏不住的。', '傻逼', "今天天气真好！"]
outputs = predict(str_list, embedding_weights)  # 预测的是一个列表

loading model......
torch.Size([7, 20])
tensor([[[ 3.8169e-01, -4.0005e-01,  2.9346e-01,  ..., -2.2200e+00,
           9.3815e-02, -7.7787e-01],
         [-6.3990e-01,  6.0148e-02,  1.3081e+00,  ..., -1.3306e+00,
          -2.8541e-01,  3.2685e-01],
         [-4.1571e-02, -6.4343e-01,  7.6042e-01,  ..., -1.4898e+00,
          -3.7161e-01, -1.8120e-01],
         ...,
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[-6.8756e-01,  4.7396e-01,  1.1422e+00,  ..., -2.0392e+00,
          -1.8019e-01, -1.9962e+00],
         [-1.3168e-01,  1.2774e+00,  1.6794e-01,  ...,  3.3911e-01,
           1.5143e-01, -1.0711e+00],
         [-1.3465e+00,  1.5103e+00,  5.1551e-01,  ...,  6.1945e-01,
           3.2719e-01, -1.1235e+00],

