# 数据处理模块

## 目录

* 数据集加载
* 读取标签和数据
* 创建word2id
* 将数据转化成id

In [1]:
from torch.utils import data
import os
import csv
import nltk
import numpy as np

In [7]:
# 数据集加载
f = open("./data/AG/train.csv")
rows = csv.reader(f,delimiter=',', quotechar='"')
rows = list(rows)
rows[0:5]

[['3',
  'Wall St. Bears Claw Back Into the Black (Reuters)',
  "Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."],
 ['3',
  'Carlyle Looks Toward Commercial Aerospace (Reuters)',
  'Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.'],
 ['3',
  "Oil and Economy Cloud Stocks' Outlook (Reuters)",
  'Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.'],
 ['3',
  'Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)',
  'Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Saturday.'],
 ['3',
  'Oil pr

In [10]:
# 读取标签和数据
n_gram = 2 
lowercase = True
label = []
datas = []
for row in rows:
    label.append(int(row[0])-1)
    txt = " ".join(row[1:])
    if lowercase:
        txt = txt.lower()
    txt = nltk.word_tokenize(txt)   # 将句子转化成词
    new_txt=  []
    for i in range(0,len(txt)):
        for j in range(n_gram):   # 添加n-gram词
            if j<=i:
                new_txt.append(" ".join(txt[i-j:i+1]))
    datas.append(new_txt)
print (label[0:5])
print (datas[0:5])

[2, 2, 2, 2, 2]
[['wall', 'st.', 'wall st.', 'bears', 'st. bears', 'claw', 'bears claw', 'back', 'claw back', 'into', 'back into', 'the', 'into the', 'black', 'the black', '(', 'black (', 'reuters', '( reuters', ')', 'reuters )', 'reuters', ') reuters', '-', 'reuters -', 'short-sellers', '- short-sellers', ',', 'short-sellers ,', 'wall', ', wall', 'street', 'wall street', "'s", "street 's", 'dwindling\\band', "'s dwindling\\band", 'of', 'dwindling\\band of', 'ultra-cynics', 'of ultra-cynics', ',', 'ultra-cynics ,', 'are', ', are', 'seeing', 'are seeing', 'green', 'seeing green', 'again', 'green again', '.', 'again .'], ['carlyle', 'looks', 'carlyle looks', 'toward', 'looks toward', 'commercial', 'toward commercial', 'aerospace', 'commercial aerospace', '(', 'aerospace (', 'reuters', '( reuters', ')', 'reuters )', 'reuters', ') reuters', '-', 'reuters -', 'private', '- private', 'investment', 'private investment', 'firm', 'investment firm', 'carlyle', 'firm carlyle', 'group', 'carlyle g

## Hash技术
<img src="./imgs/hash.png"  width="600" height="600" align="bottom" />

In [4]:
# 得到word2id
min_count = 3
word_freq = {}
for data in datas:   # 首先统计词频，后续通过词频过滤低频词
    for word in data:
        if word_freq.get(word)!=None:  
            word_freq[word]+=1
        else:
            word_freq[word] = 1
word2id = {"<pad>":0,"<unk>":1} 
for word in word_freq:   # 首先构建uni-gram词，因为不需要hash
    if word_freq[word]<min_count or " " in word:
        continue
    word2id[word] = len(word2id)
uniwords_num = len(word2id)
for word in word_freq:  # 构建2-gram以上的词，需要hash
    if word_freq[word]<min_count or " " not in word:
        continue
    word2id[word] = len(word2id)
word2id

{'<pad>': 0,
 '<unk>': 1,
 'wall': 2,
 'st.': 3,
 'bears': 4,
 'claw': 5,
 'back': 6,
 'into': 7,
 'the': 8,
 'black': 9,
 '(': 10,
 'reuters': 11,
 ')': 12,
 '-': 13,
 ',': 14,
 'street': 15,
 "'s": 16,
 'of': 17,
 'are': 18,
 'seeing': 19,
 'green': 20,
 'again': 21,
 '.': 22,
 'carlyle': 23,
 'looks': 24,
 'toward': 25,
 'commercial': 26,
 'aerospace': 27,
 'private': 28,
 'investment': 29,
 'firm': 30,
 'group': 31,
 '\\which': 32,
 'has': 33,
 'a': 34,
 'reputation': 35,
 'for': 36,
 'making': 37,
 'and': 38,
 'plays': 39,
 'in': 40,
 'defense': 41,
 'industry': 42,
 'quietly': 43,
 'bets': 44,
 'on': 45,
 'another': 46,
 'part': 47,
 'market': 48,
 'oil': 49,
 'economy': 50,
 'cloud': 51,
 'stocks': 52,
 "'": 53,
 'outlook': 54,
 'soaring': 55,
 'crude': 56,
 'prices': 57,
 'plus': 58,
 'worries\\about': 59,
 'earnings': 60,
 'expected': 61,
 'over': 62,
 'stock': 63,
 'next': 64,
 'week': 65,
 'during': 66,
 'depth': 67,
 'the\\summer': 68,
 'doldrums': 69,
 'iraq': 70,
 'halts'

In [9]:
print (list(word2id.items())[-20:])
print (len(word2id))

[('whitney scored', 297875), ('lazard files', 297876), ('bosnian serb-run', 297877), ('serb-run half', 297878), (') metameta8080', 297879), ('metameta8080 :', 297880), ('penthouse in', 297881), ('with clement', 297882), ('clippers 113-86', 297883), ('with tcdd', 297884), ('known dioxin', 297885), ('one contained', 297886), ('in agent', 297887), ('who analyzed', 297888), ('analyzed his', 297889), ('lakers 120-116', 297890), ('rafique and', 297891), ('all-star vince', 297892), ('four inmates', 297893), ('nets for', 297894)]
297895


In [6]:
# 将文本中的词都转化成id
max_length = 100
for i,data in enumerate(datas):
    for j,word in enumerate(data):
        if " " not in word:
            datas[i][j] = word2id.get(word,1)
        else:
            datas[i][j] = word2id.get(word, 1)%100000+uniwords_num  # hash函数
    datas[i] = datas[i][0:max_length]+[0]*(max_length-len(datas[i]))
datas[0:5]

[[2,
  3,
  82792,
  4,
  41397,
  5,
  41397,
  6,
  82793,
  7,
  82794,
  8,
  82795,
  9,
  82796,
  10,
  41397,
  11,
  82797,
  12,
  82798,
  11,
  82799,
  13,
  82800,
  1,
  41397,
  14,
  41397,
  2,
  82801,
  15,
  82802,
  16,
  82803,
  1,
  41397,
  17,
  41397,
  1,
  41397,
  14,
  41397,
  18,
  82804,
  19,
  82805,
  20,
  82806,
  21,
  41397,
  22,
  82807,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [23,
  24,
  41397,
  25,
  82808,
  26,
  41397,
  27,
  41397,
  10,
  41397,
  11,
  82797,
  12,
  82798,
  11,
  82799,
  13,
  82800,
  28,
  82809,
  29,
  82810,
  30,
  82811,
  23,
  41397,
  31,
  82812,
  14,
  82813,
  32,
  82814,
  33,
  41397,
  34,
  82815,
  35,
  82816,
  36,
  82817,
  37,
  82818,
  1,
  41397,
  38,
  41397,
  1,
  41397,
  39,
  41397,
 