# 数据处理模块

## 目录

* 数据集加载
* 根据长度排序
* 创建word2id
* 分句
* 将数据转化成id

In [4]:
from torch.utils import data
import os
import nltk
import numpy as np
import pickle
from collections import Counter

In [6]:
# Load the dataset.
# Each line holds fields separated by a double tab: field 2 is the rating and
# the last field is the review text. The rating string is appended as the final
# token of each document so that it travels with the text through the
# length sort in the next cell.
with open("./data/imdb/imdb-test.txt.ss", encoding="utf-8") as f:
    # Split every line exactly once; the `with` block closes the file afterwards.
    rows = [line.split("\t\t") for line in f.read().splitlines()]
datas = [fields[-1].split() + [fields[2]] for fields in rows]
datas[0:5]

[['i',
  'knew',
  'that',
  'the',
  'old-time',
  'movie',
  'makers',
  'often',
  '``',
  'borrowed',
  "''",
  'or',
  'outright',
  'plagiarized',
  'from',
  'each',
  'other',
  ',',
  'but',
  'this',
  'is',
  'ridiculous',
  '!',
  '<sssss>',
  'not',
  'only',
  'did',
  'george',
  'albert',
  'smith',
  'make',
  'this',
  'film',
  'in',
  '1899',
  ',',
  'but',
  'and',
  'company',
  'made',
  'a',
  'nearly',
  'identical',
  'film',
  'that',
  'same',
  'year',
  'with',
  'the',
  'same',
  'title',
  '!!!',
  '<sssss>',
  'the',
  'worst',
  'part',
  'about',
  'it',
  'is',
  'that',
  'neither',
  'film',
  'was',
  'all',
  'that',
  'great',
  '.',
  '<sssss>',
  'and',
  ',',
  'of',
  'the',
  'two',
  ',',
  'the',
  'smith',
  'one',
  'is',
  'slightly',
  'less',
  'well',
  'made',
  '.',
  '<sssss>',
  'like',
  'all',
  'movies',
  'of',
  'the',
  '1890s',
  ',',
  'this',
  'one',
  'is',
  'incredibly',
  'brief',
  'and',
  'almost',
  'complete

In [7]:
# Sort documents by token count, longest first (the count still includes the
# rating token appended at load time). Then peel the rating off each document
# and shift it down by one to get a 0-based class index.
datas = sorted(datas, key=len, reverse=True)
labels, documents = [], []
for doc in datas:
    *tokens, rating = doc
    labels.append(int(rating) - 1)
    documents.append(tokens)
datas = documents
print(labels[0:5])
print(datas[-5:])

[7, 9, 9, 8, 9]
[['one', 'of', 'the', 'best', 'movie', 'musicals', 'ever', 'made', '.', '<sssss>', 'the', 'singing', 'and', 'dancing', 'are', 'excellent', '.'], ['john', 'goodman', 'is', 'excellent', 'in', 'this', 'entertaining', 'portrayal', 'of', 'babe', 'ruth', "'s", 'life', '.'], ['how', 'to', 'this', 'movie', ':', 'disjointed', 'silly', 'unfulfilling', 'story', 'waste', 'of', 'time'], ['simply', 'a', 'classic', '.', '<sssss>', 'scenario', 'and', 'acting', 'are', 'excellent', '.'], ['there', 'were', 'tng', 'tv', 'episodes', 'with', 'a', 'better', 'story', '.']]


In [9]:
# Build the vocabulary (word -> id).
# Ids 0 and 1 are reserved for "<pad>" and "<unk>". Every other token that
# occurs at least `min_count` times gets the next free id, in first-seen order.
min_count = 5
# A generator expression keeps its loop variables local, so the `data` module
# imported from torch.utils is not rebound to a document list.
word_freq = Counter(word for doc in datas for word in doc)
word2id = {"<pad>": 0, "<unk>": 1}
for word, freq in word_freq.items():
    # Skip rare words, and never reassign a reserved token. Overwriting "<pad>"
    # or "<unk>" would not grow the dict, so the next word would get a duplicate id.
    if freq >= min_count and word not in word2id:
        word2id[word] = len(word2id)
word2id

{'<pad>': 0,
 '<unk>': 1,
 'i': 2,
 'only': 3,
 'just': 4,
 'got': 5,
 'around': 6,
 'to': 7,
 'watching': 8,
 'the': 9,
 'movie': 10,
 'today': 11,
 '.': 12,
 '<sssss>': 13,
 'when': 14,
 'it': 15,
 'came': 16,
 'out': 17,
 'in': 18,
 'movies': 19,
 ',': 20,
 'heard': 21,
 'so': 22,
 'many': 23,
 'bad': 24,
 'things': 25,
 'about': 26,
 '...': 27,
 'how': 28,
 'fake': 29,
 'looked': 30,
 'long': 31,
 'winded': 32,
 'and': 33,
 'boring': 34,
 'was': 35,
 'stupid': 36,
 "n't": 37,
 'all': 38,
 'that': 39,
 'great': 40,
 'etc.': 41,
 'list': 42,
 'goes': 43,
 'on': 44,
 'we': 45,
 'can': 46,
 'see': 47,
 'debates': 48,
 'here': 49,
 'as': 50,
 'well': 51,
 'then': 52,
 'there': 53,
 'were': 54,
 'of': 55,
 'course': 56,
 'mixed': 57,
 'critics': 58,
 'but': 59,
 'either': 60,
 'way': 61,
 'wanted': 62,
 'for': 63,
 'myself': 64,
 'judge': 65,
 'always': 66,
 'never': 67,
 'due': 68,
 'lack': 69,
 'time': 70,
 ':': 71,
 '-lrb-': 72,
 'watched': 73,
 'dvd': 74,
 'surprisingly': 75,
 'rathe

In [10]:
# Split every document into sentences at the "<sssss>" boundary token.
def _split_sentences(tokens, boundary="<sssss>"):
    """Split a flat token list into a list of sentences (token lists) at *boundary*.

    Only tokens exactly equal to *boundary* start a new sentence. Leading,
    trailing or consecutive boundary tokens produce empty sentences, and an
    empty document yields a single empty sentence.
    """
    sentences = [[]]
    for token in tokens:
        if token == boundary:
            sentences.append([])
        else:
            sentences[-1].append(token)
    return sentences


# Use `doc` rather than `data` so the `torch.utils.data` import is not shadowed.
datas = [_split_sentences(doc) for doc in datas]
datas[0]

[['i',
  'only',
  'just',
  'got',
  'around',
  'to',
  'watching',
  'the',
  'movie',
  'today',
  '.'],
 ['when',
  'it',
  'came',
  'out',
  'in',
  'the',
  'movies',
  ',',
  'i',
  'heard',
  'so',
  'many',
  'bad',
  'things',
  'about',
  'it',
  '...',
  'how',
  'fake',
  'it',
  'looked',
  ',',
  'how',
  'long',
  'winded',
  'and',
  'boring',
  'it',
  'was',
  ',',
  'how',
  'stupid',
  'it',
  'was',
  ',',
  'how',
  'it',
  'was',
  "n't",
  'all',
  'that',
  'great',
  'etc.',
  '.'],
 ['.'],
 ['the',
  'list',
  'goes',
  'on',
  '...',
  'we',
  'can',
  'see',
  'it',
  'in',
  'the',
  'debates',
  'here',
  'as',
  'well',
  '.'],
 ['then',
  'there',
  'were',
  'of',
  'course',
  'mixed',
  'critics',
  'about',
  'it',
  'but',
  'either',
  'way',
  ',',
  'i',
  'wanted',
  'to',
  'see',
  'it',
  'for',
  'myself',
  'to',
  'judge',
  ',',
  'as',
  'always',
  'but',
  'i',
  'never',
  'got',
  'around',
  'to',
  'it',
  ',',
  'due',
  'to',

In [11]:
# Convert tokens to ids and pad the data into rectangular batches.
max_sentence_length = 100  # every sentence is truncated/padded to this many tokens
batch_size = 64  # within one batch, every document is padded to the same number of sentences
pad_id = word2id["<pad>"]
unk_id = word2id["<unk>"]

# Map words to ids (unknown words -> "<unk>"), then truncate or right-pad each
# sentence to exactly `max_sentence_length` ids.
for i, document in enumerate(datas):
    if i % 10000 == 0:
        print(i, len(datas))
    datas[i] = [
        [word2id.get(word, unk_id) for word in sentence][:max_sentence_length]
        + [pad_id] * (max_sentence_length - len(sentence))
        for sentence in document
    ]

# Pad every document in a batch with all-pad sentences up to the longest
# document of that batch.
for start in range(0, len(datas), batch_size):
    batch = datas[start:start + batch_size]
    max_data_length = max(len(doc) for doc in batch)
    for j in range(start, start + len(batch)):
        # Build a fresh list per padding sentence: `[row] * n` would make all n
        # entries references to one shared list object.
        datas[j] = datas[j] + [
            [pad_id] * max_sentence_length
            for _ in range(max_data_length - len(datas[j]))
        ]
datas[0]

0 34029
10000 34029
20000 34029
30000 34029


[[2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [14,
  15,
  16,
  17,
  18,
  9,
  19,
  20,
  2,
  21,
  22,
  23,
  24,
  25,
  26,
  15,
  27,
  28,
  29,
  15,
  30,
  20,
  28,
  31,
  32,
  33,
  34,
  15,
  35,
  20,
  28,
  36,
  15,
  35,
  20,
  28,
  15,
  35,
  37,
  38,
  39,
  40,
  41,
  12,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,

In [12]:
import numpy as np

# Sanity check: slice documents and labels into batches and print the array
# shapes. The batches reuse `batch_size` from the padding cell. Documents were
# padded per batch of that size, so a different step here could produce ragged
# batches that np.array cannot stack.
for start in range(0, len(datas), batch_size):
    batch_datas = np.array(datas[start:start + batch_size])
    batch_labels = np.array(labels[start:start + batch_size])
    print(batch_datas.shape)
    print(batch_labels.shape)
    

(64, 205, 100)
(64,)
(64, 83, 100)
(64,)
(64, 77, 100)
(64,)
(64, 73, 100)
(64,)
(64, 67, 100)
(64,)
(64, 75, 100)
(64,)
(64, 84, 100)
(64,)
(64, 98, 100)
(64,)
(64, 69, 100)
(64,)
(64, 75, 100)
(64,)
(64, 79, 100)
(64,)
(64, 93, 100)
(64,)
(64, 88, 100)
(64,)
(64, 76, 100)
(64,)
(64, 64, 100)
(64,)
(64, 65, 100)
(64,)
(64, 65, 100)
(64,)
(64, 76, 100)
(64,)
(64, 69, 100)
(64,)
(64, 63, 100)
(64,)
(64, 58, 100)
(64,)
(64, 68, 100)
(64,)
(64, 61, 100)
(64,)
(64, 58, 100)
(64,)
(64, 76, 100)
(64,)
(64, 56, 100)
(64,)
(64, 47, 100)
(64,)
(64, 59, 100)
(64,)
(64, 49, 100)
(64,)
(64, 48, 100)
(64,)
(64, 52, 100)
(64,)
(64, 47, 100)
(64,)
(64, 44, 100)
(64,)
(64, 44, 100)
(64,)
(64, 56, 100)
(64,)
(64, 50, 100)
(64,)
(64, 55, 100)
(64,)
(64, 46, 100)
(64,)
(64, 48, 100)
(64,)
(64, 60, 100)
(64,)
(64, 59, 100)
(64,)
(64, 52, 100)
(64,)
(64, 50, 100)
(64,)
(64, 46, 100)
(64,)
(64, 46, 100)
(64,)
(64, 44, 100)
(64,)
(64, 47, 100)
(64,)
(64, 38, 100)
(64,)
(64, 48, 100)
(64,)
(64, 91, 100)
(64,)

(64, 14, 100)
(64,)
(64, 14, 100)
(64,)
(64, 15, 100)
(64,)
(64, 18, 100)
(64,)
(64, 16, 100)
(64,)
(64, 12, 100)
(64,)
(64, 14, 100)
(64,)
(64, 13, 100)
(64,)
(64, 16, 100)
(64,)
(64, 13, 100)
(64,)
(64, 17, 100)
(64,)
(64, 15, 100)
(64,)
(64, 15, 100)
(64,)
(64, 13, 100)
(64,)
(64, 13, 100)
(64,)
(64, 15, 100)
(64,)
(64, 12, 100)
(64,)
(64, 16, 100)
(64,)
(64, 14, 100)
(64,)
(64, 22, 100)
(64,)
(64, 17, 100)
(64,)
(64, 14, 100)
(64,)
(64, 13, 100)
(64,)
(64, 13, 100)
(64,)
(64, 11, 100)
(64,)
(64, 13, 100)
(64,)
(64, 14, 100)
(64,)
(64, 11, 100)
(64,)
(64, 12, 100)
(64,)
(64, 12, 100)
(64,)
(64, 13, 100)
(64,)
(64, 15, 100)
(64,)
(64, 12, 100)
(64,)
(64, 16, 100)
(64,)
(64, 13, 100)
(64,)
(64, 14, 100)
(64,)
(64, 13, 100)
(64,)
(64, 13, 100)
(64,)
(64, 14, 100)
(64,)
(64, 14, 100)
(64,)
(64, 12, 100)
(64,)
(64, 19, 100)
(64,)
(64, 12, 100)
(64,)
(64, 13, 100)
(64,)
(64, 12, 100)
(64,)
(64, 12, 100)
(64,)
(64, 14, 100)
(64,)
(64, 11, 100)
(64,)
(64, 10, 100)
(64,)
(64, 14, 100)
(64,)
