In [14]:
import jieba
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange, tqdm

In [15]:
# 数据预处理
vocab = {}
vocab['[PAD]'] = 0
# 参数设置
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seq_len = 35  # 每句话的长度
embedding_size = 150  # word2vec的维度
kernel_num = 75  # 每一种卷积核的个数
Kernel_list = [2, 3, 4, 5, 6]  # N-gram
class_num = 9  # 分类的个数
epoches = 1500  #训练次数
lr = 0.001  # 学习率

In [16]:
#词向量分配（每一个词都对应一个数字）
def pro_vocab(path):
    with open(path, 'r', encoding='utf-8') as f:
        L = [e.strip() for e in f.readlines()]
        for news in L:
            news = news[2:]
            for word in news:
                if word not in vocab:
                    vocab[word] = len(vocab)

In [17]:
pro_vocab('train.txt')
pro_vocab('test.txt')
pro_vocab('test_4.txt')

vocab_size = len(vocab)

In [18]:
def word2ids(text):  # 将数据转化成数字id 并多退少补
    ids = [vocab[word] for word in text]
    if len(ids) < seq_len:
        ids += [0] * (seq_len - len(ids))
        return ids
    else:
        return ids[:seq_len]

In [19]:
def load_data(path):  # 读取训练集和测试集
    with open(path, 'r', encoding='utf-8') as f:
        turples = [(int(sentence[0]), word2ids(sentence[2:].strip()))
                   for sentence in f.readlines()]
    labels = []
    texts = []
    for turple in turples:
        labels.append(turple[0])
        texts.append(turple[1])
    # 返回标签和经过预处理的文本
    return texts, labels

In [20]:
train_x, train_y = load_data('train.txt')
test_x, test_y = load_data('test.txt')
train_x = torch.LongTensor(train_x).to(device)
train_y = torch.LongTensor(train_y).to(device)
test_x = torch.LongTensor(test_x).to(device)
test_y = torch.LongTensor(test_y).to(device)

In [21]:
class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        V = vocab_size
        E = embedding_size
        Ci = 1  # 输入数据的通道数
        Co = kernel_num  # 每一种卷积核的数目
        Kl = Kernel_list  # N-gram
        C = class_num  #输出的维度

        self.embed = nn.Embedding(V, E)
        self.convs = nn.ModuleList([nn.Conv2d(Ci, Co, (K, E)) for K in Kl])
        self.fc = nn.Linear(len(Kl) * Co, C)

    def forward(self, x):
        x = self.embed(x)  # (N, seq_len, E)

        x = x.unsqueeze(1)  # (N, Ci = 1, seq_len, E)

        x = [F.relu(conv(x)).squeeze(3)
             for conv in self.convs]  # [(N, Co, seq_len-ki+1), ...]*len(Ks)

        x = [F.max_pool1d(i, i.size(2)).squeeze(2)
             for i in x]  # [(N, Co), ...]*len(Ks)
        x = torch.cat(x, 1)
        out = F.softmax(self.fc(x), dim=1)
        return out

In [22]:
model = TextCNN().to(device)  #实例化模型
criterion = nn.CrossEntropyLoss().to(device)  #定义损失函数，为交叉熵损失
optimizer = torch.optim.Adam(model.parameters(), lr=lr)  #使用随机梯度下降算法，学习率为lr

for epoch in trange(epoches): #训练过程
    pred = model(train_x)
    loss = criterion(pred, train_y)
    if (epoch + 1) % 100 == 0:
        print(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  7%|█████▎                                                                         | 100/1500 [00:23<05:24,  4.32it/s]

1.6990526914596558


 13%|██████████▌                                                                    | 200/1500 [00:46<05:01,  4.31it/s]

1.6950310468673706


 20%|███████████████▊                                                               | 300/1500 [01:09<04:40,  4.27it/s]

1.6937613487243652


 27%|█████████████████████                                                          | 400/1500 [01:32<04:17,  4.27it/s]

1.6104191541671753


 33%|██████████████████████████▎                                                    | 500/1500 [01:56<03:54,  4.27it/s]

1.609690546989441


 40%|███████████████████████████████▌                                               | 600/1500 [02:19<03:32,  4.24it/s]

1.6094458103179932


 47%|████████████████████████████████████▊                                          | 700/1500 [02:43<03:08,  4.24it/s]

1.6090658903121948


 53%|██████████████████████████████████████████▏                                    | 800/1500 [03:07<02:45,  4.24it/s]

1.608972191810608


 60%|███████████████████████████████████████████████▍                               | 900/1500 [03:30<02:22,  4.21it/s]

1.608913779258728


 67%|████████████████████████████████████████████████████                          | 1000/1500 [03:54<01:58,  4.23it/s]

1.6088753938674927


 73%|█████████████████████████████████████████████████████████▏                    | 1100/1500 [04:17<01:35,  4.19it/s]

1.608847975730896


 80%|██████████████████████████████████████████████████████████████▍               | 1200/1500 [04:41<01:10,  4.23it/s]

1.5151641368865967


 87%|███████████████████████████████████████████████████████████████████▌          | 1300/1500 [05:05<00:47,  4.20it/s]

1.5147992372512817


 93%|████████████████████████████████████████████████████████████████████████▊     | 1400/1500 [05:29<00:23,  4.21it/s]

1.5146818161010742


100%|██████████████████████████████████████████████████████████████████████████████| 1500/1500 [05:52<00:00,  4.25it/s]

1.5146257877349854





In [23]:
pred_y = model(test_x)
L = pred_y.tolist()
total = len(L)
acc = 0
for i, pred in enumerate(L):
    max_index = pred.index(max(pred))
    if max_index == test_y[i]:
        acc += 1

print(acc / total * 100, '%')

63.5 %


In [24]:
with open('test_4.txt', 'r', encoding='utf-8') as f:
    all_data = f.readlines()
    id_ = []
    texts = []
    for i, line in enumerate(all_data):
        r = line.split()
        id_.append(int(r[0]))
        texts.append(word2ids(r[1]))
#转化为tensor的类型        
#id_ = torch.LongTensor(id_).to(device)
texts = torch.LongTensor(texts).to(device)

In [25]:
# 我们生成一个新的数据文本，并将所有新闻标题写入新文件
pred = model(texts)
L = pred.tolist()
with open('10185102144-predictions.txt', 'w',encoding='utf-8') as f:
    for i, pred in enumerate(L):
        max_index = pred.index(max(pred))
        f.write(str(id_[i])+'\t'+str(max_index)+'\n')

In [26]:
pred_y = model(test_x)
L = pred_y.tolist()
total = len(L)
for i, pred in enumerate(L):
    max_index = pred.index(max(pred))
    print(max_index)

0
0
0
1
8
1
1
1
1
1
0
1
2
2
2
0
8
2
3
3
0
3
3
1
3
3
3
3
2
4
4
4
4
4
4
4
4
4
4
4
4
1
1
1
1
3
3
8
8
2
8
8
2
4
1
1
8
8
8
0
8
8
2
3
1
0
2
0
1
1
0
2
2
2
0
3
3
3
3
3
4
4
4
4
4
1
1
1
8
8
8
0
1
2
8
1
1
1
1
1
1
1
3
3
8
4
4
4
4
1
3
8
0
0
2
1
1
0
3
3
3
4
4
4
1
8
2
8
8
8
0
0
1
2
3
0
1
2
1
1
1
3
0
3
3
3
3
4
1
3
8
3
2
1
3
8
4
0
1
1
1
2
3
4
3
8
8
0
0
0
8
1
1
0
1
2
2
2
1
8
3
4
3
4
4
1
1
3
1
8
3
0
3
2
0
1
1
1
2
3
1
3
3
4
4
1
3
8
3
3
2
1
8
2
2
0
2
2
2
3
2
3
4
4
4
1
8
0
0
1
1
2
2
3
4
3
2
4
4
4
4
4
1
1
1
1
3
3
3
8
8
3
3
0
1
1
1
1
1
1
1
1
2
2
2
2
3
2
3
3
3
0
3
2
8
0
4
4
4
4
4
4
4
4
4
8
4
3
8
8
2
2
0
0
1
1
0
8
3
3
3
0
3
3
3
4
4
1
0
3
1
1
1
3
3
8
2
2
0
2
3
0
1
1
1
1
2
2
3
2
3
3
3
1
3
3
2
4
4
0
4
4
4
1
2
2
2
3
2
4
4
1
8
1
8
0
0
1
3
1
3
3
3
3
1
1
1
1
8
2
0
2
0
2
1
1
1
0
2
2
2
1
3
3
4
4
2
4
4
1
1
3
8
1
0
8
8
8
8
1


In [27]:
L

[[0.9117724895477295,
  2.9434577299980447e-05,
  0.04774027690291405,
  0.016126494854688644,
  0.00017451911116950214,
  1.9449129240456386e-07,
  7.782991140459217e-09,
  1.3334314763469024e-09,
  0.02415647730231285],
 [0.5403991341590881,
  7.74168256612029e-06,
  0.3622424006462097,
  0.01182844303548336,
  0.0496782585978508,
  1.4866269566482515e-06,
  1.6236866429153451e-07,
  5.045613349352607e-08,
  0.03584228828549385],
 [0.8383316993713379,
  0.001463268301449716,
  0.024242054671049118,
  0.129526749253273,
  5.877584044355899e-05,
  5.680774606275918e-08,
  5.1622897068170914e-09,
  5.205673669905764e-10,
  0.006377420388162136],
 [3.320997620903654e-06,
  0.9993207454681396,
  0.0006036202539689839,
  1.2471550689951982e-05,
  2.7846312150359154e-05,
  1.7229327964329855e-09,
  5.796029628468702e-11,
  8.399847484241718e-11,
  3.191971336491406e-05],
 [0.017637982964515686,
  0.012178846634924412,
  0.13395792245864868,
  0.00013988764840178192,
  0.05808670446276665,
 