In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
from  torch.utils.data import Dataset,DataLoader
from tqdm import tqdm

In [148]:
def read_data(train_or_test,num=None):
    with open(os.path.join(".","data",train_or_test + ".txt"),encoding="utf-8") as f:
        all_data = f.read().split("\n")

    texts = []
    labels = []
    for data in all_data:
        if data:
            t,l = data.split("\t")
            texts.append(t)
            labels.append(l)
    if num == None:
        return texts,labels
    else:
        return texts[:num],labels[:num]

构建语料库

In [149]:
def built_corpus(train_texts,embedding_num):
    word_2_index = {"<PAD>":0,"<UNK>":1}
    for text in train_texts:
        for word in text:
            word_2_index[word] = word_2_index.get(word,len(word_2_index))
    return word_2_index,nn.Embedding(len(word_2_index),embedding_num)

构建数据集

In [150]:
class TextDataset(Dataset):
    def __init__(self,all_text,all_label,word_2_index,max_len):
        self.all_text = all_text
        self.all_label = all_label
        self.word_2_index = word_2_index
        self.max_len = max_len

    def __getitem__(self,index):
        text = self.all_text[index][:self.max_len]
        label = int(self.all_label[index])

        text_idx = [self.word_2_index.get(i,1) for i in text]
        text_idx = text_idx + [0] * (self.max_len - len(text_idx))

        text_idx = torch.tensor(text_idx).unsqueeze(dim=0)

        return text_idx,label

    def __len__(self):
        return len(self.all_text)

构建模块

In [151]:
class Block(nn.Module):
    def __init__(self,kernel_s,embeddin_num,max_len,hidden_num):
        super().__init__()
        self.cnn = nn.Conv1d(in_channels=embeddin_num,out_channels=hidden_num,kernel_size=kernel_s) #  200 * 50 * 20 (batch *  embedding_num * text_len  )
        self.act = nn.ReLU()
        self.mxp = nn.MaxPool1d(kernel_size=(max_len-kernel_s+1))

    def forward(self,batch_emb): # 200 * 50 * 20 (batch_size * embedding_size * text_len )
        c = self.cnn.forward(batch_emb)
        a = self.act.forward(c)
        a = a.squeeze(dim=-1)
        m = self.mxp.forward(a)
        m = m.squeeze(dim=-1)
        return m

模块组成模型

In [152]:
class TextCNNModel(nn.Module):
    def __init__(self,emb_matrix,max_len,class_num,hidden_num):
        super().__init__()
        self.emb_num = emb_matrix.weight.shape[1]

        self.block1 = Block(2,self.emb_num,max_len,hidden_num)
        self.block2 = Block(3,self.emb_num,max_len,hidden_num)
        self.block3 = Block(4,self.emb_num,max_len,hidden_num)
        self.block4 = Block(5, self.emb_num, max_len, hidden_num)

        self.emb_matrix = emb_matrix

        self.classfier = nn.Linear(hidden_num*4,class_num) # 全连接层，可以视为分类器
        self.loss_fun = nn.CrossEntropyLoss()
    
    def forward(self,batch_idx,batch_label=None):
        batch_emb = self.emb_matrix(batch_idx) # 输入维数 200 * 50 * 20
        batch_emb = batch_emb.permute(0, 2, 1)
        b1_result = self.block1.forward(batch_emb)
        b2_result = self.block2.forward(batch_emb)
        b3_result = self.block3.forward(batch_emb)
        b4_result = self.block4.forward(batch_emb)

        feature = torch.cat([b1_result,b2_result,b3_result,b4_result],dim=1) # 1 * 6 : [ batch * (3 * 2)]
        pre = self.classfier(feature) # 存疑 讲解为概率，权重，与下面的预测值间的区别不是很清楚
        
        # 如果有标签，那么就输出损失值；否则输出预测值
        if batch_label is not None:
            loss = self.loss_fun(pre,batch_label)
            return loss
        else:
            return torch.argmax(pre,dim=-1)
        

In [153]:
train_text,train_label = read_data("train")
validation_text,validation_label =  read_data("dev")

embedding = 50
max_len= 20
batch_size = 200
epoch = 100
lr = 0.001
hidden_num = 2
class_num = len(set(train_label))
device = "cuda:0" if torch.cuda.is_available() else "cpu"

word_2_index,words_embedding = built_corpus(train_text,embedding)

train_dataset = TextDataset(train_text,train_label,word_2_index,max_len)
train_loader = DataLoader(train_dataset,batch_size,shuffle=False)

validation_dataset = TextDataset(validation_text,validation_label,word_2_index,max_len)
validation_loader = DataLoader(validation_dataset,batch_size,shuffle=False)


model = TextCNNModel(words_embedding,max_len,class_num,hidden_num).to(device)
opt = torch.optim.AdamW(model.parameters(),lr=lr)


for e in range(epoch):
    print(f"epoc {e}")

    for batch_idx,batch_label in train_loader:
        batch_idx = torch.reshape(batch_idx,(batch_size,max_len))
        batch_idx = batch_idx.to(device)
        batch_label = batch_label.to(device)
        loss = model.forward(batch_idx,batch_label)
        loss.backward()
        opt.step()
        opt.zero_grad()

    print(f"loss:{loss:.3f}")

    right_num = 0
    for batch_idx,batch_label in validation_loader:
        batch_idx = torch.reshape(batch_idx,(batch_size,max_len))
        batch_idx = batch_idx.to(device)
        batch_label = batch_label.to(device)
        pre = model.forward(batch_idx)
        right_num += int(torch.sum(pre==batch_label))

    print(f"acc = {right_num/len(validation_text)*100:.2f}%")
    print()

epoc 0
loss:1.231
acc = 63.68%

epoc 1
loss:0.994
acc = 70.87%

epoc 2
loss:0.873
acc = 73.52%

epoc 3
loss:0.811
acc = 75.32%

epoc 4
loss:0.759
acc = 76.60%

epoc 5
loss:0.724
acc = 77.46%

epoc 6
loss:0.705
acc = 78.09%

epoc 7
loss:0.685
acc = 78.58%

epoc 8
loss:0.672
acc = 78.93%

epoc 9
loss:0.648
acc = 79.32%

epoc 10
loss:0.632
acc = 79.52%

epoc 11
loss:0.618
acc = 79.66%

epoc 12
loss:0.606
acc = 79.78%

epoc 13
loss:0.589
acc = 79.85%

epoc 14
loss:0.573
acc = 79.77%

epoc 15
loss:0.559
acc = 80.16%

epoc 16
loss:0.549
acc = 80.25%

epoc 17
loss:0.532
acc = 80.30%

epoc 18
loss:0.521
acc = 80.41%

epoc 19
loss:0.509
acc = 80.38%

epoc 20
loss:0.499
acc = 80.49%

epoc 21
loss:0.490
acc = 80.48%

epoc 22
loss:0.483
acc = 80.57%

epoc 23
loss:0.470
acc = 80.53%

epoc 24
loss:0.463
acc = 80.56%

epoc 25
loss:0.454
acc = 80.57%

epoc 26
loss:0.447
acc = 80.50%

epoc 27
loss:0.439
acc = 80.63%

epoc 28
loss:0.435
acc = 80.49%

epoc 29
loss:0.428
acc = 80.53%

epoc 30
loss:0.420
a