News Topic Classification
===

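Given the text of a news report as input, we use a model to determine which type of news it most likely belongs to. This is a typical text classification task. We assume the categories are mutually exclusive, i.e., each text belongs to exactly one category.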
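Because the categories are mutually exclusive, the model only needs to output one score (logit) per class, and the prediction is the class with the highest score. This single-label setting is what `nn.CrossEntropyLoss` with integer class targets is designed for. A minimal pure-Python sketch of the decision rule; the class names assume the four AG_NEWS topics in label order, and the logits are made-up numbers:

```python
import math

# Assumed class names for AG_NEWS labels 0..3 (illustration only)
CLASSES = ["World", "Sports", "Business", "Sci/Tech"]


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits):
    """Single-label prediction: index of the most probable class."""
    probs = softmax(logits)
    return max(range(len(probs)), key=probs.__getitem__)


logits = [0.2, -1.0, 2.5, 0.7]  # made-up scores for one article
probs = softmax(logits)
print(CLASSES[predict(logits)])  # Business
```

Softmax does not change the ordering of the scores, so taking the argmax of the raw logits gives the same answer; that is why the training loop below can use `output.argmax(1)` directly.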

# 1. Load the data

In [4]:
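import torch
import os
# `text_classification` exists only in older torchtext releases (around 0.4 to 0.8);
# current torchtext versions expose AG_NEWS through a different API
from torchtext.datasets import text_classification

load_data_path = "./data"
os.makedirs(load_data_path, exist_ok=True)  # create the download directory if needed

# Downloads ag_news_csv.tar.gz, builds the vocabulary and returns numericalized datasets
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](root=load_data_path)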

ag_news_csv.tar.gz: 11.8MB [00:01, 8.07MB/s]
120000lines [00:04, 25903.66lines/s]
120000lines [00:08, 14700.93lines/s]
7600lines [00:00, 13262.97lines/s]


In [29]:
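import pandas as pd
ROOT_PATH = os.path.join(load_data_path, 'ag_news_csv')
# train.csv has no header line (columns: label, title, description), so pandas takes the
# first sample as the column names, as the output shows; pass header=None to keep it as data
train_data = pd.read_csv(os.path.join(ROOT_PATH, 'train.csv'))
train_data.head(10)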

|   | 3 | Wall St. Bears Claw Back Into the Black (Reuters) | Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again. |
|---|---|---|---|
| 0 | 3 | Carlyle Looks Toward Commercial Aerospace (Reu... | Reuters - Private investment firm Carlyle Grou... |
| 1 | 3 | Oil and Economy Cloud Stocks' Outlook (Reuters) | Reuters - Soaring crude prices plus worries\ab... |
| 2 | 3 | Iraq Halts Oil Exports from Main Southern Pipe... | Reuters - Authorities have halted oil export\f... |
| 3 | 3 | Oil prices soar to all-time record, posing new... | AFP - Tearaway world oil prices, toppling reco... |
| 4 | 3 | Stocks End Up, But Near Year Lows (Reuters) | Reuters - Stocks ended slightly higher on Frid... |
| 5 | 3 | Money Funds Fell in Latest Week (AP) | AP - Assets of the nation's retail money marke... |
| 6 | 3 | Fed minutes show dissent over inflation (USATO... | USATODAY.com - Retail sales bounced back a bit... |
| 7 | 3 | Safety Net (Forbes.com) | Forbes.com - After earning a PH.D. in Sociolog... |
| 8 | 3 | Wall St. Bears Claw Back Into the Black | NEW YORK (Reuters) - Short-sellers, Wall Stre... |
| 9 | 3 | Oil and Economy Cloud Stocks' Outlook | NEW YORK (Reuters) - Soaring crude prices plu... |


In [30]:
test_data = pd.read_csv(os.path.join(ROOT_PATH, 'test.csv'))
test_data.head(10)

|   | 3 | Fears for T N pension after talks | Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul. |
|---|---|---|---|
| 0 | 4 | The Race is On: Second Private Team Sets Launc... | SPACE.com - TORONTO, Canada -- A second\team o... |
| 1 | 4 | Ky. Company Wins Grant to Study Peptides (AP) | AP - A company founded by a chemistry research... |
| 2 | 4 | Prediction Unit Helps Forecast Wildfires (AP) | AP - It's barely dawn when Mike Fitzpatrick st... |
| 3 | 4 | Calif. Aims to Limit Farm-Related Smog (AP) | AP - Southern California's smog-fighting agenc... |
| 4 | 4 | Open Letter Against British Copyright Indoctri... | The British Department for Education and Skill... |
| 5 | 4 | Loosing the War on Terrorism | \\"Sven Jaschan, self-confessed author of the ... |
| 6 | 4 | FOAFKey: FOAF, PGP, Key Distribution, and Bloo... | \\FOAF/LOAF and bloom filters have a lot of i... |
| 7 | 4 | E-mail scam targets police chief | Wiltshire Police warns about "phishing" after ... |
| 8 | 4 | Card fraud unit nets 36,000 cards | In its first two years, the UK's dedicated car... |
| 9 | 4 | Group to Propose New High-Speed Wireless Format | LOS ANGELES (Reuters) - A group of technology... |
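The header row of the two tables above is really the first sample of each file: the AG_NEWS CSV files have no header line, so `pd.read_csv` promotes the first row to column names. Passing `header=None, names=['label', 'title', 'description']` to `pd.read_csv` keeps every row as data. The same idea with Python's standard `csv` module, as a sketch that assumes each line holds a quoted label, title and description; `read_ag_news` is a hypothetical helper, and the two rows are the header rows from the outputs above:

```python
import csv
import io

# Assumed layout: every line is "label","title","description" and there is NO header line.
SAMPLE = (
    '"3","Wall St. Bears Claw Back Into the Black (Reuters)",'
    '"Reuters - Short-sellers, Wall Street\'s dwindling\\band of ultra-cynics, are seeing green again."\n'
    '"3","Fears for T N pension after talks",'
    '"Unions representing workers at Turner Newall say they are \'disappointed\' '
    'after talks with stricken parent firm Federal Mogul."\n'
)


def read_ag_news(fileobj):
    """Parse every line as a sample; no line is treated as a header."""
    rows = []
    for label, title, description in csv.reader(fileobj):
        rows.append({"label": int(label), "title": title, "description": description})
    return rows


rows = read_ag_news(io.StringIO(SAMPLE))
print(len(rows))         # 2: the first line is kept as a sample
print(rows[0]["title"])  # Wall St. Bears Claw Back Into the Black (Reuters)
```

Quoting matters here: titles and descriptions contain commas, which `csv.reader` handles because the fields are enclosed in double quotes.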


# 2. Model design

## 2.1. Build a text classification model with an Embedding layer

In [31]:
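import torch.nn as nn
import torch.nn.functional as F

# The original notebook imports `device` from the `jjzhk` helper package;
# selecting it with plain PyTorch avoids that dependency
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 16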

In [32]:
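class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # sparse=True makes the embedding produce sparse gradients (supported by optim.SGD)
        self.embedding = nn.Embedding(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, input):
        # input: 1-D tensor holding the token ids of the whole batch, concatenated
        embedded = self.embedding(input)              # (num_tokens, embed_dim)
        c = embedded.size(0) // BATCH_SIZE            # tokens per pooling window
        embedded_ex = embedded[:BATCH_SIZE * c]       # drop leftover tokens
        embedded_ex = embedded_ex.transpose(1, 0).unsqueeze(0)  # (1, embed_dim, BATCH_SIZE * c)
        # Average BATCH_SIZE consecutive, equal-sized windows of c tokens -> (1, embed_dim, BATCH_SIZE).
        # The windows do not in general line up with the boundaries between samples,
        # and the output always has BATCH_SIZE rows, so every batch must be full.
        embedded_ex = F.avg_pool1d(embedded_ex, kernel_size=c)
        embedded_ex = embedded_ex[0].transpose(1, 0)  # (BATCH_SIZE, embed_dim)
        embedded_ex = self.fc(embedded_ex)            # (BATCH_SIZE, num_class) logits
        return embedded_ex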
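In `forward`, all tokens of a batch arrive concatenated, and `F.avg_pool1d` with `kernel_size=c` (its stride defaults to the kernel size) averages `BATCH_SIZE` non-overlapping windows of `c` tokens each. The pure-Python sketch below, built around a hypothetical helper `chunk_groups`, lists which sample each token in each window comes from. It assumes only the sample lengths matter, which holds because the windows are cut purely by position:

```python
def chunk_groups(lengths, batch_size):
    """Mimic the pooling windows of TextSentiment.forward.

    `lengths` are the token counts of the samples in one batch. Returns, for each of the
    `batch_size` windows, the index of the sample every token in that window belongs to.
    """
    # Sample index of each token after the batch has been concatenated
    owners = [s for s, n in enumerate(lengths) for _ in range(n)]
    c = len(owners) // batch_size      # window size, as computed in forward()
    kept = owners[: batch_size * c]    # tokens past batch_size * c are dropped
    return [kept[k * c:(k + 1) * c] for k in range(batch_size)]


print(chunk_groups([3, 5], 2))  # [[0, 0, 0, 1], [1, 1, 1, 1]]: window 0 mixes both samples
print(chunk_groups([2, 5], 2))  # [[0, 0, 1], [1, 1, 1]]: the last token is dropped
```

Only when every sample in the batch has the same length, as in `chunk_groups([4, 4], 2)`, does each window coincide with one sample; news texts of varying length generally break that alignment.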


In [33]:
VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUM_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)

## 2.2. Batch the data

In [34]:
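def generate_batch(batch):
    # batch: list of (label, token_id_tensor) pairs produced by the dataset
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    text = torch.cat(text)  # concatenate all samples into one 1-D tensor
    return text, label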

In [37]:
batch = [(1, torch.tensor([3, 32, 2, 8])), (0, torch.tensor([3, 45, 21, 6]))]
res = generate_batch(batch)
print(res)

(tensor([ 3, 32,  2,  8,  3, 45, 21,  6]), tensor([1, 0]))
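Because `generate_batch` concatenates all samples into one tensor, the output no longer records where each sample starts. A common remedy, used for example in PyTorch's text classification tutorial, is to also return per-sample start offsets and pool with `nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')`, whose `forward(input, offsets)` averages the embeddings of each sample separately. The pure-Python sketch below only illustrates that grouping: `generate_batch_with_offsets` and `mean_bags` are hypothetical helpers, and the 2-dimensional table is a toy stand-in for real embedding weights.

```python
def generate_batch_with_offsets(batch):
    """Hypothetical collate variant: concatenated token ids, per-sample start offsets, labels."""
    labels = [label for label, _ in batch]
    offsets, text, start = [], [], 0
    for _, tokens in batch:
        offsets.append(start)  # where this sample's tokens begin in `text`
        text.extend(tokens)
        start += len(tokens)
    return text, offsets, labels


def mean_bags(text, offsets, table):
    """Average the embedding rows of each sample (the grouping EmbeddingBag mode='mean' uses)."""
    bounds = offsets + [len(text)]
    pooled = []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        rows = [table[t] for t in text[lo:hi]]
        pooled.append([sum(col) / len(rows) for col in zip(*rows)])
    return pooled


batch = [(1, [3, 32, 2, 8]), (0, [3, 45, 21, 6])]  # same ids as the example above, as lists
text, offsets, labels = generate_batch_with_offsets(batch)
table = {t: [float(t), 1.0] for t in text}          # toy 2-d "embeddings" keyed by token id
pooled = mean_bags(text, offsets, table)
print(offsets)  # [0, 4]
print(pooled)   # [[11.25, 1.0], [18.75, 1.0]]
```

Each pooled row depends only on its own sample's tokens, regardless of how long the other samples in the batch are.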


## 2.3. Set up training and validation

In [38]:
from torch.utils.data import DataLoader
import time
from torch.utils.data.dataset import random_split
from jjzhk.progressbar import ProgressBar  # progress-bar helper from the jjzhk package


N_EPOCHS = 10
min_valid_loss = float('inf')
criterion = torch.nn.CrossEntropyLoss().to(device)
# optim.SGD supports the sparse gradients produced by nn.Embedding(sparse=True)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
# Multiply the learning rate by 0.9 every time scheduler.step() is called (once per epoch)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

# Hold out 5% of the training set for validation
train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = random_split(train_dataset, [train_len, len(train_dataset) - train_len])
train_data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)
valid_data = DataLoader(sub_valid_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
bar_train = ProgressBar(N_EPOCHS, len(train_data), "loss:%.3f;acc:%.3f")
bar_test = ProgressBar(1, len(valid_data), "loss:%.3f;acc:%.3f")
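The batch counts in the training log follow from these settings. A small sketch of the arithmetic, assuming the 120,000 training lines reported when the dataset was loaded: the 95/5 split, the number of batches per epoch (a `DataLoader` keeps a final partial batch unless `drop_last=True`), and the learning rate that `StepLR(step_size=1, gamma=0.9)` gives each epoch when `scheduler.step()` runs once after it.

```python
import math

TOTAL_TRAIN = 120000  # training lines reported while loading AG_NEWS
BATCH_SIZE = 16
N_EPOCHS = 10
LR, GAMMA = 4.0, 0.9  # SGD lr and StepLR gamma from the cell above

train_len = int(TOTAL_TRAIN * 0.95)                # same expression as the notebook
valid_len = TOTAL_TRAIN - train_len
train_batches = math.ceil(train_len / BATCH_SIZE)  # a final partial batch would be kept
valid_batches = math.ceil(valid_len / BATCH_SIZE)

# Epoch k (counting from 1) trains with LR * GAMMA ** (k - 1)
lrs = [LR * GAMMA ** (k - 1) for k in range(1, N_EPOCHS + 1)]

print(train_len, valid_len, train_batches, valid_batches)  # 114000 6000 7125 375
print([round(lr, 3) for lr in lrs[:3]])                    # [4.0, 3.6, 3.24]
```

Both 114000 and 6000 are multiples of 16, so every batch is full; `TextSentiment.forward` relies on this, since it always returns `BATCH_SIZE` rows.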

## 2.4. Train and validate the model

In [39]:
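for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss = 0
    train_acc = 0

    for i, (text, label) in enumerate(train_data):
        text, label = text.to(device), label.to(device)  # same device as the model
        optimizer.zero_grad()
        output = model(text)
        loss = criterion(output, label)  # mean cross-entropy over the batch
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == label).sum().item()
        # train_loss sums batch means, so the loss shown here is (mean loss) / BATCH_SIZE
        bar_train.show(epoch + 1, train_loss / (BATCH_SIZE * (i + 1)), train_acc / (BATCH_SIZE * (i + 1)))

    scheduler.step()  # decay the learning rate once per epoch

    valid_loss = 0
    valid_acc = 0

    for i, (text, label) in enumerate(valid_data):
        text, label = text.to(device), label.to(device)
        with torch.no_grad():
            output = model(text)
            # Use a separate name so the running total is not overwritten by the batch loss
            batch_loss = criterion(output, label)
            valid_loss += batch_loss.item()
            valid_acc += (output.argmax(1) == label).sum().item()
        bar_test.show(1, valid_loss / (BATCH_SIZE * (i + 1)), valid_acc / (BATCH_SIZE * (i + 1)))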

Epoch:1 / 10 [****************************************************************************************************] 7125 / 7125 ,loss:0.059;acc:0.644,total=0:00:19
Epoch:1 / 1 [*****************************************************************************************************************************] 375 / 375 ,loss:0.000;acc:0.709,total=0:00:00
Epoch:2 / 10 [****************************************************************************************************] 7125 / 7125 ,loss:0.052;acc:0.700,total=0:00:19
Epoch:1 / 1 [*****************************************************************************************************************************] 375 / 375 ,loss:0.000;acc:0.717,total=0:00:00
Epoch:3 / 10 [****************************************************************************************************] 7125 / 7125 ,loss:0.051;acc:0.711,total=0:00:19
Epoch:1 / 1 [***************************************************************************************************************************
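`nn.CrossEntropyLoss` returns the mean loss of a batch, so dividing the summed `loss.item()` values by `BATCH_SIZE * (i + 1)` reports that mean divided by `BATCH_SIZE`: the 0.059 logged for epoch 1, for instance, corresponds to a mean per-sample loss of roughly 0.94. A hedged sketch of a running average that weights each batch by its actual size, written as a hypothetical generator `running_metrics` over plain numbers so it can be checked without a model:

```python
def running_metrics(batches):
    """Yield (mean loss per sample, accuracy) after each batch.

    Each item of `batches` is (batch_mean_loss, num_correct, batch_size), i.e. the values of
    criterion(output, label).item(), (output.argmax(1) == label).sum().item() and
    label.size(0) for one batch.
    """
    loss_sum = 0.0
    correct = 0
    seen = 0
    for batch_loss, batch_correct, n in batches:
        loss_sum += batch_loss * n  # undo the per-batch averaging of CrossEntropyLoss
        correct += batch_correct
        seen += n
        yield loss_sum / seen, correct / seen


history = list(running_metrics([(1.0, 8, 16), (0.5, 12, 16)]))
print(history)  # [(1.0, 0.5), (0.75, 0.625)]
```

Weighting by `n` instead of assuming `BATCH_SIZE` keeps the averages correct if the last batch of a split is smaller than the others.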

## 2.5. Inspect the word vectors learned by the embedding layer

In [40]:
print(model.state_dict()['embedding.weight'])

tensor([[ 0.1924,  0.4844,  0.0672,  ..., -0.3984, -0.0840, -0.0872],
        [-0.0263, -0.2393,  0.0808,  ...,  0.4940,  0.0689, -0.4583],
        [-0.1608, -0.0168, -0.0021,  ...,  0.1455, -0.0015, -0.0460],
        ...,
        [ 0.1416,  0.3386, -0.0076,  ..., -0.3229, -0.2432, -0.1422],
        [-0.3922, -0.0447,  0.0746,  ..., -0.0954,  0.0334, -0.2103],
        [ 0.3947, -0.4426,  0.4626,  ...,  0.1671, -0.1298, -0.2242]])
