In [1]:
!ls

 01-Fasttext			        data_process.ipynb
 02-TextCNN			        data_utils.py
 04-TextRCNN			        images
'05-Hierarchical Attention Networks '   README.md
 06-memory-networks		        utils.py
'07-attention is all your need'


In [3]:
!pip install wget

[31mdistributed 1.21.8 requires msgpack, which is not installed.[0m
[33mYou are using pip version 10.0.1, however version 18.0 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [4]:
import os
import wget
import tarfile
import re
from nltk.tokenize import word_tokenize
import collections
import pandas as pd
import pickle
import numpy as np

In [33]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/panxie/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

### 下载数据集

In [5]:
def download_dbpedia():
    dbpedia_url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz"

    wget.download(dbpedia_url)
    with tarfile.open("dbpedia_csv.tar.gz", "r:gz") as tar:
        tar.extractall()

In [6]:
download_dbpedia()

In [7]:
!ls

 01-Fasttext			        data_utils.py
 02-TextCNN			        dbpedia_csv
 04-TextRCNN			        dbpedia_csv.tar.gz
'05-Hierarchical Attention Networks '   images
 06-memory-networks		        README.md
'07-attention is all your need'         utils.py
 data_process.ipynb


In [8]:
!ls dbpedia_csv/

classes.txt  readme.txt  test.csv  train.csv


In [21]:
train_df = pd.read_csv("dbpedia_csv/train.csv", names=["classes", "title", "content"])

In [22]:
print(train_df.shape)

(560000, 3)


In [23]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 560000 entries, 0 to 559999
Data columns (total 3 columns):
classes    560000 non-null int64
title      560000 non-null object
content    560000 non-null object
dtypes: int64(1), object(2)
memory usage: 12.8+ MB


In [24]:
train_df.head(10)

Unnamed: 0,classes,title,content
0,1,E. D. Abbott Ltd,Abbott of Farnham E D Abbott Limited was a Br...
1,1,Schwan-Stabilo,Schwan-STABILO is a German maker of pens for ...
2,1,Q-workshop,Q-workshop is a Polish company located in Poz...
3,1,Marvell Software Solutions Israel,Marvell Software Solutions Israel known as RA...
4,1,Bergan Mercy Medical Center,Bergan Mercy Medical Center is a hospital loc...
5,1,The Unsigned Guide,The Unsigned Guide is an online contacts dire...
6,1,Rest of the world,Within sports and games played at the interna...
7,1,Globoforce,Globoforce is a multinational company co-head...
8,1,Rompetrol,The Rompetrol Group N.V. is a Romanian oil co...
9,1,Wave Accounting,Wave is the brand name for a suite of online ...


### 词典
- 使用nltk分词
- 使用正则化去除特殊符号

In [26]:
contents = train_df['content']

In [31]:
def clean_str(text):
    # 正则化处理特殊字符
    text = re.sub(r"[^A-Za-z0-9(),!?\'`\"]", " ", text)
    text = re.sub(r"\s{2,}", " ", text)
    text = text.strip().lower()

    return text

In [34]:
words = []
train_data = []
for content in contents:
    for word in word_tokenize(clean_str(content)):
        
        words.append(word)

In [35]:
print(len(words))

27518975


In [36]:
words[:10]

['abbott',
 'of',
 'farnham',
 'e',
 'd',
 'abbott',
 'limited',
 'was',
 'a',
 'british']

In [37]:
word_counter = collections.Counter(words)  # dict

In [42]:
print(len(word_counter), type(word_counter), word_counter['of'], word_counter['a'], word_counter['abbott'])

563260 <class 'collections.Counter'> 878840 772644 218


In [43]:
word_counter = word_counter.most_common()  # list

In [46]:
print(len(word_counter), word_counter[:5])

563260 [('the', 1666524), ('in', 939594), ('of', 878840), ('a', 772644), ('is', 761186)]


In [53]:
word_dict = {}
word_dict['<pad>'] = 0
word_dict['<unk>'] = 1
word_dict['<eos>'] = 2
for word,_ in word_counter:
    word_dict[word] = len(word_dict)  ## 机智

In [55]:
word_dict['a'], word_dict['<unk>'],word_dict['<eos>']

(6, 1, 2)

### 得到预处理后的训练集

In [56]:
# Shuffle
train_df = train_df.sample(frac=1)

In [57]:
train_df.head()

Unnamed: 0,classes,title,content
245182,7,Luther House,Luther House is a historic house at 177 Marke...
31308,1,Cision,Cision AB is a Swedish software company. Cisi...
185153,5,Edward Baigent,Edward Baigent (22 June 1813 – 9 November 189...
379573,10,Turritella,Turritella is a genus of medium-sized sea sna...
154946,4,Peter Rafferty,Peter Rafferty was a Northern Irish footballe...


In [58]:
# train data 中每一行分词
x = list(map(lambda d: word_tokenize(clean_str(d)), train_df["content"]))

In [61]:
print(x[:2])

[['luther', 'house', 'is', 'a', 'historic', 'house', 'at', '177', 'market', 'street', 'in', 'swansea', 'massachusetts', 'the', 'house', 'was', 'built', 'in', '1740', 'and', 'added', 'to', 'the', 'national', 'historic', 'register', 'in', '1990'], ['cision', 'ab', 'is', 'a', 'swedish', 'software', 'company', 'cision', 'has', 'offices', 'in', 'europe', 'north', 'america', 'and', 'asia', 'as', 'of', '2011', 'it', 'has', 'revenues', 'of', 'sek', '1', 'billion']]


In [63]:
# 以第一行为例，得到每个次在 word_dict 的index， 不存在的词赋予 word_dict['<unk>']
print(list(map(lambda word: word_dict.get(word, word_dict['<unk>']), x[0])))

[6353, 68, 7, 6, 84, 68, 20, 12246, 866, 290, 4, 7232, 418, 3, 68, 11, 59, 4, 10947, 8, 530, 13, 3, 60, 84, 254, 4, 476]


In [64]:
# 把 list(map(lambda word: word_dict.get(word, word_dict['<unk>']), d) 看作一个参数为d函数
# list(map(lambda d: f, x))
x = list(map(lambda d:list(map(lambda word: word_dict.get(word, word_dict['<unk>']), d)), x))

In [67]:
print(x[0], len(x))

[6353, 68, 7, 6, 84, 68, 20, 12246, 866, 290, 4, 7232, 418, 3, 68, 11, 59, 4, 10947, 8, 530, 13, 3, 60, 84, 254, 4, 476] 560000


In [68]:
# 加上结束词
x = list(map(lambda d: d + [word_dict["<eos>"]], x))

In [69]:
print(x[0], len(x))

[6353, 68, 7, 6, 84, 68, 20, 12246, 866, 290, 4, 7232, 418, 3, 68, 11, 59, 4, 10947, 8, 530, 13, 3, 60, 84, 254, 4, 476, 2] 560000


In [83]:
# 设定一个最长序列值
x = list(map(lambda d:d[:30], x[:2]))

In [84]:
print(x[:2])

[[6353, 68, 7, 6, 84, 68, 20, 12246, 866, 290], [146319, 3087, 7, 6, 702, 840, 56, 146319, 32, 1121]]


In [90]:
# 不足这个长度的补零
x = list(map(lambda d: d + (15 - len(d)) * [word_dict["<pad>"]], x))

In [91]:
print(x)

[[6353, 68, 7, 6, 84, 68, 20, 12246, 866, 290, 0, 0, 0, 0, 0], [146319, 3087, 7, 6, 702, 840, 56, 146319, 32, 1121, 0, 0, 0, 0, 0]]


#### 真实标签

In [88]:
y = list(map(lambda d: d - 1, list(train_df["classes"])))

In [89]:
y[:10]

[6, 0, 4, 9, 3, 1, 8, 1, 5, 1]

### 迭代处理数据

In [None]:
def batch_iter(inputs, outputs, batch_size, num_epochs):
    inputs = np.array(inputs)
    outputs = np.array(outputs)

    num_batches_per_epoch = (len(inputs) - 1) // batch_size + 1
    for epoch in range(num_epochs):
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, len(inputs))
            yield inputs[start_index:end_index], outputs[start_index:end_index]