In [57]:
import jieba
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange, tqdm

In [58]:
# 数据预处理
vocab = {}
vocab['[PAD]'] = 0
# 参数设置
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seq_len = 35  # 每句话的长度
embedding_size = 150  # word2vec的维度
kernel_num = 75  # 每一种卷积核的个数
Kernel_list = [2, 3, 4, 5, 6]  # N-gram
class_num = 9  # 分类的个数
epoches = 1500  #训练次数
lr = 0.001  # 学习率

In [59]:
#词向量分配（每一个词都对应一个数字）
def pro_vocab(path):
    with open(path, 'r', encoding='utf-8') as f:
        L = [e.strip() for e in f.readlines()]
        for news in L:
            news = news[2:]
            for word in news:
                if word not in vocab:
                    vocab[word] = len(vocab)

In [60]:
pro_vocab('train.txt')
pro_vocab('test.txt')
pro_vocab('test_4.txt')

vocab_size = len(vocab)

In [61]:
def word2ids(text):  # 将数据转化成数字id 并多退少补
    ids = [vocab[word] for word in text]
    if len(ids) < seq_len:
        ids += [0] * (seq_len - len(ids))
        return ids
    else:
        return ids[:seq_len]

In [62]:
def load_data(path):  # 读取训练集和测试集
    with open(path, 'r', encoding='utf-8') as f:
        turples = [(int(sentence[0]), word2ids(sentence[2:].strip()))
                   for sentence in f.readlines()]
    labels = []
    texts = []
    for turple in turples:
        labels.append(turple[0])
        texts.append(turple[1])
    # 返回标签和经过预处理的文本
    return texts, labels

In [63]:
train_x, train_y = load_data('train.txt')
test_x, test_y = load_data('test.txt')
train_x = torch.LongTensor(train_x).to(device)
train_y = torch.LongTensor(train_y).to(device)
test_x = torch.LongTensor(test_x).to(device)
test_y = torch.LongTensor(test_y).to(device)

In [64]:
class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        V = vocab_size
        E = embedding_size
        Ci = 1  # 输入数据的通道数
        Co = kernel_num  # 每一种卷积核的数目
        Kl = Kernel_list  # N-gram
        C = class_num  #输出的维度

        self.embed = nn.Embedding(V, E)
        self.convs = nn.ModuleList([nn.Conv2d(Ci, Co, (K, E)) for K in Kl])
        self.fc = nn.Linear(len(Kl) * Co, C)

    def forward(self, x):
        x = self.embed(x)  # (N, seq_len, E)

        x = x.unsqueeze(1)  # (N, Ci = 1, seq_len, E)

        x = [F.relu(conv(x)).squeeze(3)
             for conv in self.convs]  # [(N, Co, seq_len-ki+1), ...]*len(Ks)

        x = [F.max_pool1d(i, i.size(2)).squeeze(2)
             for i in x]  # [(N, Co), ...]*len(Ks)
        x = torch.cat(x, 1)
        out = F.softmax(self.fc(x), dim=1)
        return out

In [65]:
model = TextCNN().to(device)  #实例化模型
criterion = nn.CrossEntropyLoss().to(device)  #定义损失函数，为交叉熵损失
optimizer = torch.optim.Adam(model.parameters(), lr=lr)  #使用随机梯度下降算法，学习率为lr

for epoch in trange(epoches): #训练过程
    pred = model(train_x)
    loss = criterion(pred, train_y)
    if (epoch + 1) % 100 == 0:
        print(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

 67%|█████████████████████████████████████████████████████▎                          | 100/150 [00:22<00:11,  4.39it/s]

1.704274296760559


100%|████████████████████████████████████████████████████████████████████████████████| 150/150 [00:34<00:00,  4.38it/s]


In [66]:
pred_y = model(test_x)
L = pred_y.tolist()
total = len(L)
acc = 0
for i, pred in enumerate(L):
    max_index = pred.index(max(pred))
    if max_index == test_y[i]:
        acc += 1

print(acc / total * 100, '%')

54.75 %


In [87]:
import operator

In [93]:
with open('test_4.txt', 'r', encoding='utf-8') as f:
    all_data = f.readlines()
    id_ = []
    texts = []
    for i, line in enumerate(all_data):
        r = line.split()
        id_.append(int(r[0]))
        texts.append(word2ids(r[1]))
#转化为tensor的类型        
#id_ = torch.LongTensor(id_).to(device)
texts = torch.LongTensor(texts).to(device)

In [95]:
# 我们生成一个新的数据文本，并将所有新闻标题写入新文件
pred = model(texts)
L = pred.tolist()
with open('10185102144-predictions.txt', 'w',encoding='utf-8') as f:
    for i, pred in enumerate(L):
        max_index = pred.index(max(pred))
        print(id_[i])
        f.write(str(id_[i])+'\t'+str(max_index)+'\n')

3875
2999
356
483
130
2197
3188
1462
4475
3139
3289
4214
2275
1919
3147
550
4820
2224
3024
2704
1451
193
3035
952
1629
333
455
3861
2376
4497
4235
1375
1890
892
3239
1424
1204
1134
1578
4339
1807
1390
453
2329
4350
4910
2273
1214
2899
2038
4720
3909
1293
1641
3173
2437
3231
2312
2358
4155
2660
3647
208
2564
4146
3076
2811
4303
1704
3006
2372
1619
2810
185
4782
142
2748
3994
3331
1761
2440
3274
1168
3297
4922
2121
2156
2381
1430
1469
2230
4299
3866
4667
1813
1728
4766
457
2292
4067
1748
3517
2432
2045
94
377
70
2104
3025
4428
4205
2713
767
2881
1256
1146
1762
1987
3020
3432
3864
1215
1715
2068
1876
267
2400
1011
113
3931
738
661
1460
1015
2843
541
2409
3926
3720
1635
3425
2387
884
1031
474
201
3908
3965
306
3401
2326
549
95
1265
610
3985
3441
2596
4565
958
2261
4935
3871
4916
2472
2643
2614
3874
112
3767
1983
4343
1673
2119
2752
741
2327
2681
1301
2514
3822
2077
2317
1089
2529
725
3108
164
397
163
3824
1041
881
805
2174
4328
1478
1065
3855
2887
4989
3227
2151
447
1835
3558
1386
852
3261

In [96]:
pred_y = model(test_x)
L = pred_y.tolist()
total = len(L)
for i, pred in enumerate(L):
    max_index = pred.index(max(pred))
    print(max_index)

0
0
0
1
3
1
1
1
1
1
1
1
1
0
3
0
4
1
3
3
0
3
3
1
3
3
3
3
3
4
4
4
4
4
4
4
4
4
4
4
4
1
1
1
1
3
3
3
3
3
4
4
0
3
1
1
4
0
1
1
4
3
0
0
1
0
1
0
0
1
0
0
0
0
3
3
0
3
3
3
4
4
4
4
4
1
1
1
1
3
3
0
0
1
1
1
1
1
1
1
1
1
3
3
3
4
4
4
4
1
3
4
0
0
4
3
1
3
3
3
3
4
4
4
1
4
1
4
1
1
0
0
1
3
3
0
1
0
1
1
1
3
0
3
0
3
3
4
1
1
3
1
4
1
0
4
4
0
1
1
1
3
3
4
3
0
1
0
0
0
0
1
1
0
1
1
0
0
1
0
3
3
3
4
4
1
1
3
1
0
3
0
0
1
0
1
1
1
1
3
3
3
1
4
4
1
3
0
3
3
3
3
0
1
3
3
3
4
0
3
1
3
4
4
4
1
3
0
0
1
1
0
0
0
1
3
3
4
4
4
4
4
1
0
1
1
3
1
3
3
1
3
3
0
1
1
1
1
1
1
1
1
0
0
1
3
3
0
3
3
3
3
3
0
4
3
4
4
4
4
4
4
4
4
4
1
3
3
0
0
0
0
0
0
1
1
0
4
3
3
3
0
3
3
3
4
4
3
3
3
1
1
1
4
3
3
0
3
0
3
3
0
1
1
1
0
0
0
3
0
3
3
3
0
3
3
3
4
4
1
4
4
4
1
0
1
1
3
3
4
4
1
3
1
0
0
0
1
3
1
3
3
3
3
1
1
1
1
3
0
0
1
0
0
1
1
1
0
1
0
0
3
3
3
4
4
3
4
4
1
1
3
1
3
1
1
3
0
3
0


In [97]:
L

[[0.9170469641685486,
  0.042246319353580475,
  2.1580615339189535e-06,
  0.023305805400013924,
  0.017398467287421227,
  1.3457049874432414e-07,
  6.342296643424561e-08,
  1.1213004746934985e-08,
  7.580101879511858e-08],
 [0.7219216227531433,
  0.020956408232450485,
  3.682677561300807e-05,
  0.038834430277347565,
  0.21824847161769867,
  9.444895567867206e-07,
  3.835621100733988e-07,
  3.03389612099636e-07,
  7.117919835764042e-07],
 [0.4783562421798706,
  0.06988511234521866,
  9.423039955436252e-06,
  0.438825398683548,
  0.012922994792461395,
  3.257255798416736e-07,
  1.607358228739031e-07,
  3.265381209871521e-08,
  2.617375969293789e-07],
 [0.005358982365578413,
  0.9888784289360046,
  3.0044752747926395e-06,
  0.004369549918919802,
  0.001389788230881095,
  1.0175405407153448e-07,
  9.054095073679491e-08,
  3.113680691058107e-08,
  4.134653508458541e-08],
 [0.40237554907798767,
  0.08463398367166519,
  4.439756139618112e-06,
  0.480300635099411,
  0.0326848030090332,
  3.573