In [1]:
import os
import jieba
import random
import torch
from torch import nn
from tqdm import tqdm

In [2]:
def get_file_path(category, frac=0.2):
    dir_path = f'THUCNews/{category}/'
    file_list = os.listdir(dir_path)
    file_list = random.sample(file_list, int(frac * len(file_list)))
    file_list = [dir_path + file for file in file_list]
    return file_list

categories = os.listdir('../THUCNews/')[:3]
print(categories)
get_file_path(categories[1])[:5]

['时尚', '家居', '教育']


['../THUCNews/家居/240053.txt',
 '../THUCNews/家居/225706.txt',
 '../THUCNews/家居/244168.txt',
 '../THUCNews/家居/236354.txt',
 '../THUCNews/家居/231429.txt']

In [3]:
def get_tokens(file):
    res = []
    with open(file, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                tokens = list(jieba.cut(line))
                tokens = [token for token in tokens if token not in set("\u3000\n 。,:!！“?…”《》，；—（）-：？^~`[]|()")]
                res += tokens
    return res

example_file = get_file_path("体育")[0]
get_tokens(example_file)[:10]

Building prefix dict from the default dictionary ...
Loading model from cache /var/folders/97/5s9vc3hs2nj936lfdqch4zsr0000gn/T/jieba.cache
Loading model cost 0.400 seconds.
Prefix dict has been built successfully.


['马德兴', '地方官', '为', '全运', '不让', '新星', '留洋', '只', '为', '自己']

In [4]:
topic_list = {}
for topic in categories:
    file_paths = get_file_path(topic)
    topic_list[topic] = [get_tokens(file) for file in tqdm(file_paths)]

100%|██████████| 2673/2673 [00:05<00:00, 467.21it/s]
100%|██████████| 6517/6517 [00:15<00:00, 420.49it/s]
100%|██████████| 8387/8387 [00:33<00:00, 252.44it/s]


In [5]:
word_to_id, id_to_word = {}, {}
for value in topic_list.values():
    for text in value:
        for word in text:
            if word not in word_to_id:
                new_id = len(word_to_id)
                word_to_id[word] = new_id
                id_to_word[new_id] = word

n_words = len(word_to_id)
print(n_words)

196271


In [6]:
def text_to_tensor(text):
    tensor = torch.zeros(len(text), dtype=torch.long)
    for li, word in enumerate(text):
        try:
            ind = word_to_id[word]
        except KeyError:
            ind = n_words - 1
        tensor[li] = ind
    return tensor

example_file = get_file_path("时尚")[0]
text_to_tensor(get_tokens(example_file))

tensor([ 6767,    89,   496, 20233,  9530, 40450,  4719,  2110,  5455,  6047,
        40451, 40452,     5,   200,  9231,   325,   414,  2202,   414,   754,
        13855,     5,  2152,   238,   144,  4719,  3965,    46,  1028,  2152,
          118,  1673,   237,   278,   118,    59, 40453,   754,   585,  4807,
          249,   247,  6767,  4557,     5, 20233,    66,     5,  4042,   241,
         4274,     5,  1163,   348,   840,   241,   118, 12748,     5,  6903,
          348,   754,   277,  3600,     5,  2152,   200,   247, 20831,     5,
        17694, 40454,  2071, 23879,   927,  2795,   848,  3628,     5, 40455,
          225, 21213,  1228,   754,   277, 40456, 17931,     5, 20233,  8467,
         4807,   305,   249,   247,  4923,   483,    25,  2897,  1578,  1012,
          328, 20233,  6804,  1507, 35025,   247,   148,  3628,     5,    29,
          519,     5, 40455,   225,  2795,   848,  5953, 18768,     5, 21213,
         1024,  3600,    11,   118,   328,   483,  5953,  3710, 

In [7]:
all_data = []
    
for ind, value in enumerate(topic_list.values()):
    for tokens in tqdm(value):
        all_data.append((text_to_tensor(tokens), torch.tensor([ind], dtype=torch.long)))
    
random.shuffle(all_data)
data_len = len(all_data)
split_ratio = 0.8

train_data = all_data[:int(data_len * split_ratio)]
test_data = all_data[int(data_len * split_ratio):]
print("Train data size:", len(train_data))
print("Test data size:", len(test_data))

100%|██████████| 2673/2673 [00:01<00:00, 1754.63it/s]
100%|██████████| 6517/6517 [00:04<00:00, 1545.03it/s]
100%|██████████| 8387/8387 [00:09<00:00, 926.74it/s] 

Train data size: 14061
Test data size: 3516





## 2. 构造数据集

In [8]:
class LSTM(nn.Module):
    def __init__(self, word_count, embedding_size, hidden_size, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(word_count, embedding_size)
        self.LSTM = nn.LSTM(embedding_size, hidden_size, num_layers=2, bidirectional=True, batch_first=True)
        self.cls = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=0)
        
    def forward(self, input_tensor):
        word_vector = self.embedding(input_tensor)
        output = self.LSTM(word_vector)[0][0][len(input_tensor)-1]
        output = output.reshape(2, -1).sum(axis=0)
        output = self.cls(output)
        output = self.softmax(output)
        return output

In [9]:
def run_lstm(rnn, input_tensor):
    output = rnn(input_tensor.unsqueeze(dim=0))
    return output

In [10]:
def train(rnn, criterion, input_tensor, category_tensor):
    rnn.zero_grad()
    output = run_lstm(rnn, input_tensor)
    loss = criterion(output.unsqueeze(dim=0), category_tensor)
    loss.backward()

    # 根据梯度更新模型的参数
    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item()

In [11]:
def evaluate(lstm, input_tensor):
    with torch.no_grad():
        output = run_lstm(lstm, input_tensor)
        return output

In [12]:
from tqdm import tqdm

epoch = 10
embedding_size = 128
n_hidden = 64
n_categories = 3
learning_rate = 0.005
lstm = LSTM(n_words, embedding_size, n_hidden, n_categories)
criterion = nn.NLLLoss()
loss_sum = 0
all_losses = []
plot_every = 100
for e in range(epoch):
    for ind, (title_tensor, label) in enumerate(tqdm(train_data)):
        output, loss = train(lstm, criterion, title_tensor, label)
        loss_sum += loss
        if ind % plot_every == 0:
            all_losses.append(loss_sum / plot_every)
            loss_sum = 0
    c = 0
    for title, category in tqdm(test_data):
        output = evaluate(lstm, title)
        topn, topi = output.topk(1)
        if topi.item() == category[0].item():
            c += 1
    print('accuracy', c / len(test_data))

100%|██████████| 14061/14061 [4:19:42<00:00,  1.11s/it]     
100%|██████████| 3516/3516 [34:57<00:00,  1.68it/s]   


accuracy 0.8734357224118316


100%|██████████| 14061/14061 [2:58:24<00:00,  1.31it/s]    
100%|██████████| 3516/3516 [01:38<00:00, 35.60it/s]


accuracy 0.9169510807736063


100%|██████████| 14061/14061 [46:48<00:00,  5.01it/s] 
100%|██████████| 3516/3516 [01:38<00:00, 35.65it/s]


accuracy 0.9416951080773607


100%|██████████| 14061/14061 [47:11<00:00,  4.97it/s] 
100%|██████████| 3516/3516 [01:39<00:00, 35.46it/s]


accuracy 0.9434015927189988


100%|██████████| 14061/14061 [47:03<00:00,  4.98it/s] 
100%|██████████| 3516/3516 [01:38<00:00, 35.71it/s]


accuracy 0.9527872582480091


100%|██████████| 14061/14061 [46:54<00:00,  5.00it/s] 
100%|██████████| 3516/3516 [01:38<00:00, 35.53it/s]


accuracy 0.947098976109215


100%|██████████| 14061/14061 [47:09<00:00,  4.97it/s] 
100%|██████████| 3516/3516 [01:38<00:00, 35.57it/s]


accuracy 0.9496587030716723


100%|██████████| 14061/14061 [47:25<00:00,  4.94it/s] 
100%|██████████| 3516/3516 [01:39<00:00, 35.28it/s]


accuracy 0.9553469852104665


100%|██████████| 14061/14061 [47:22<00:00,  4.95it/s] 
100%|██████████| 3516/3516 [01:39<00:00, 35.50it/s]


accuracy 0.9573378839590444


100%|██████████| 14061/14061 [4:30:59<00:00,  1.16s/it]     
100%|██████████| 3516/3516 [01:38<00:00, 35.64it/s]

accuracy 0.9556313993174061



