# 数据处理模块

## 目录

* 数据集加载
* 构建word2id并去除低频词
* 构建共现矩阵
* 生成训练集
* 保存结果

In [1]:
from torch.utils import data
import os
import numpy as np
import pickle

In [2]:
min_count = 50

In [3]:
# 数据集加载
data = open("./data/text8.txt").read()
data = data.split()
# 构建word2id并去除低频词
word2freq = {}
for word in data:
    if word2freq.get(word)!=None:
        word2freq[word] += 1
    else:
        word2freq[word] = 1
word2id = {}
for word in word2freq:
    if word2freq[word]<min_count:
        continue
    else:
        if word2id.get(word)==None:
            word2id[word]=len(word2id)
print (len(word2id))
word2id

18497


{'anarchism': 0,
 'originated': 1,
 'as': 2,
 'a': 3,
 'term': 4,
 'of': 5,
 'abuse': 6,
 'first': 7,
 'used': 8,
 'against': 9,
 'early': 10,
 'working': 11,
 'class': 12,
 'radicals': 13,
 'including': 14,
 'the': 15,
 'english': 16,
 'revolution': 17,
 'and': 18,
 'sans': 19,
 'french': 20,
 'whilst': 21,
 'is': 22,
 'still': 23,
 'in': 24,
 'pejorative': 25,
 'way': 26,
 'to': 27,
 'describe': 28,
 'any': 29,
 'act': 30,
 'that': 31,
 'violent': 32,
 'means': 33,
 'destroy': 34,
 'organization': 35,
 'society': 36,
 'it': 37,
 'has': 38,
 'also': 39,
 'been': 40,
 'taken': 41,
 'up': 42,
 'positive': 43,
 'label': 44,
 'by': 45,
 'self': 46,
 'defined': 47,
 'anarchists': 48,
 'word': 49,
 'derived': 50,
 'from': 51,
 'greek': 52,
 'without': 53,
 'ruler': 54,
 'chief': 55,
 'king': 56,
 'political': 57,
 'philosophy': 58,
 'belief': 59,
 'rulers': 60,
 'are': 61,
 'unnecessary': 62,
 'should': 63,
 'be': 64,
 'abolished': 65,
 'although': 66,
 'there': 67,
 'differing': 68,
 'inte

In [4]:
# 构建共现矩阵
vocab_size = len(word2id)
comat = np.zeros((vocab_size,vocab_size))
print(comat.shape)

(18497, 18497)


In [5]:
window_size = 2

In [6]:
for i in range(len(data)):
    if i%1000000==0:
        print (i,len(data))
    if word2id.get(data[i])==None:
        continue
    w_index = word2id[data[i]]
    for j in range(max(0,i-window_size),min(len(data),i+window_size+1)):
        if word2id.get(data[j]) == None or i==j:
            continue
        u_index = word2id[data[j]]
        comat[w_index][u_index]+=1
comat

0 17005207
1000000 17005207
2000000 17005207
3000000 17005207
4000000 17005207
5000000 17005207
6000000 17005207
7000000 17005207
8000000 17005207
9000000 17005207
10000000 17005207
11000000 17005207
12000000 17005207
13000000 17005207
14000000 17005207
15000000 17005207
16000000 17005207
17000000 17005207


array([[2.4000e+01, 1.0000e+00, 2.7000e+01, ..., 0.0000e+00, 0.0000e+00,
        0.0000e+00],
       [1.0000e+00, 0.0000e+00, 5.6000e+01, ..., 0.0000e+00, 0.0000e+00,
        0.0000e+00],
       [2.7000e+01, 5.6000e+01, 1.7426e+04, ..., 1.0000e+00, 0.0000e+00,
        3.0000e+00],
       ...,
       [0.0000e+00, 0.0000e+00, 1.0000e+00, ..., 2.0000e+00, 0.0000e+00,
        0.0000e+00],
       [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
        0.0000e+00],
       [0.0000e+00, 0.0000e+00, 3.0000e+00, ..., 0.0000e+00, 0.0000e+00,
        0.0000e+00]])

In [7]:
coocs = np.transpose(np.nonzero(comat))
coocs

array([[    0,     0],
       [    0,     1],
       [    0,     2],
       ...,
       [18496, 15009],
       [18496, 15510],
       [18496, 18403]])

In [8]:
# 生成训练集
labels = []
for i in range(len(coocs)):
    if i%1000000==0:
        print (i,len(coocs))
    labels.append(comat[coocs[i][0]][coocs[i][1]])
labels = np.array(labels)
print (labels.shape)

0 9190921
1000000 9190921
2000000 9190921
3000000 9190921
4000000 9190921
5000000 9190921
6000000 9190921
7000000 9190921
8000000 9190921
9000000 9190921
(9190921,)


In [9]:
labels

array([24.,  1., 27., ...,  1.,  1.,  2.])

In [12]:
# 保存结果
np.save("./data/data.npy",coocs)
np.save("./data/label.npy",labels)
pickle.dump(word2id,open("./data/word2id","wb"))