In [1]:
import os
import warnings
# Suppress all Python warnings emitted after this point in the session.
# NOTE(review): this also hides deprecation warnings; consider narrowing with `category=`.
warnings.filterwarnings("ignore")

from collections import Counter
import numpy as np

import tensorflow as tf

In [2]:
# Load the IMDB reviews as integer-encoded sequences. The keyword arguments are the
# Keras defaults, spelled out because the `+ 3` id offset used when building
# `word2idx` relies on them (0 = padding, 1 = start marker, 2 = out-of-vocabulary).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(
    start_char=1,
    oov_char=2,
    index_from=3,
)

In [3]:
# Sanity check: shape of the training array and token count of the first review.
(x_train.shape, len(x_train[0]))

((25000,), 218)

In [4]:
# Raw Keras word index (word -> id). Ids are shifted by 3 to line up with the
# sequences returned by `load_data`, and the three reserved ids are added afterwards
# so they take precedence over any identically-named entry.
_word2idx = tf.keras.datasets.imdb.get_word_index()
word2idx = {
    **{word: raw_id + 3 for word, raw_id in _word2idx.items()},
    '<pad>': 0,
    '<start>': 1,
    '<unk>': 2,
}

# Inverse mapping (id -> word); for duplicate ids the later entry wins.
idx2word = dict(zip(word2idx.values(), word2idx.keys()))

In [5]:
def write_file(f_path, x, y):
    """Write integer-encoded reviews to a text file, one review per line.

    Each output line is the label, a tab character, and then the review's tokens
    joined by single spaces. The tokens are obtained by decoding every id through
    the module-level ``idx2word`` mapping.

    Args:
        f_path: Destination file path. Its parent directory is created if missing.
        x: Iterable of reviews, each an iterable of integer word ids.
        y: Iterable of labels, paired element-wise with ``x``. As with ``zip``,
            extra items in the longer of ``x``/``y`` are ignored.

    Raises:
        KeyError: If a review contains an id that is not a key of ``idx2word``.
    """
    # Create the directory of the file actually being written. Previously this was
    # hard-coded to ./data, so any other destination directory failed to open.
    dirname = os.path.dirname(os.path.abspath(f_path))
    os.makedirs(dirname, exist_ok=True)

    with open(f_path, 'w', encoding='utf-8') as f:
        for review, label in zip(x, y):
            content = ' '.join(idx2word[idx] for idx in review)
            f.write(f"{label}\t{content}\n")

# Dump both splits as "<label>\t<tokens>" text files (train first, then test).
for split_path, split_x, split_y in (
    ('./data/train.txt', x_train, y_train),
    ('./data/test.txt', x_test, y_test),
):
    write_file(split_path, split_x, split_y)

# 创建词表 (Build the vocabulary)

In [6]:
# Count token frequencies. Only the training split is read, so test reviews do not
# contribute to the vocabulary.
counter = Counter()
with open('./data/train.txt', 'r', encoding='utf-8') as f:
    for line in f:  # stream line by line instead of loading the file with readlines()
        line = line.strip()
        if not line:
            continue  # tolerate blank lines rather than failing on the unpack below
        # Lines are "<label>\t<space-separated tokens>" (the format write_file emits);
        # split on the first tab only so the token text is kept intact.
        label, content = line.split('\t', 1)
        counter.update(content.split(' '))


In [7]:
# Tokens must occur strictly more than MIN_FREQ times to enter the vocabulary.
MIN_FREQ = 10

# Most frequent first (Counter.most_common order), filtered by frequency.
words = [token for token, count in counter.most_common() if count > MIN_FREQ]

# Guarantee the reserved tokens exist; any that were filtered out or never occurred
# in the text are appended at the end, in this order.
# NOTE(review): '<pad>' therefore need not sit at position 0 of this list — confirm
# downstream code does not assume it does.
for special in ('<pad>', '<start>', '<unk>'):
    if special not in words:
        words.append(special)

In [8]:
# Vocabulary size, including the reserved tokens ensured by the previous cell.
vocab_size = len(words)
vocab_size

19501

In [9]:
# Persist the vocabulary to ./vocab/word.txt, one token per line, in `words` order.
dirname = "./vocab"
# exist_ok avoids the check-then-create race of `if not os.path.exists(...)`.
os.makedirs(dirname, exist_ok=True)

with open(os.path.join(dirname, 'word.txt'), 'w', encoding='utf-8') as f:
    f.writelines(f"{w}\n" for w in words)
        