In [42]:
import os
import random
import numpy as np

class plot_data(object):
    """Load one-line movie plots per genre and serve them as padded id batches.

    The first line of each file under ``<path>/Adventure``, ``<path>/Documentary``,
    ``<path>/Horror`` and ``<path>/Romance`` is read (the same number of files from
    every genre). A frequency-ranked vocabulary capped at ``max_vocab`` entries is
    built, every plot is encoded as a ``max_len`` int32 id vector, the examples are
    shuffled with a fixed seed and split into train / test parts.

    Args:
        path: dataset root containing one sub-directory per genre.
        max_vocab: maximum vocabulary size, including the end token and ``<unk>``.
        max_len: number of ids per encoded plot; longer plots are truncated, shorter
            ones are zero-padded (0 is also the end-token id).
        end_token: token appended to every plot; it gets id 0.
        train_ratio: fraction of the shuffled examples used for training.

    Labels are ``genre_index % 2`` (Adventure/Horror -> 0, Documentary/Romance -> 1),
    i.e. the values the original ``num % 2`` produced in unshuffled order.
    NOTE(review): confirm whether a 4-way genre label was intended instead.
    """

    # Genre sub-directories, in the order their files are interleaved before shuffling.
    GENRES = ("Adventure", "Documentary", "Horror", "Romance")

    def __init__(self, path="./dataset/poster_txt", max_vocab=20000, max_len=90, end_token="<eos>",
                 train_ratio=0.8):
        # Batch cursors; val_pt is kept for compatibility although no validation split exists.
        self.train_pt, self.val_pt, self.test_pt = 0, 0, 0
        self.path = path
        self.max_len = max_len
        self.max_vocab = max_vocab
        self.end_token = end_token

        # Special tokens take the first ids; 0 doubles as the padding value.
        self.w2idx = {end_token: 0, "<unk>": 1}
        self.x_ids, self.x_len, self.x_label = self.files_to_ids(path)
        self.vocab_size = len(self.w2idx)

        self.train_size = int(len(self.x_ids) * train_ratio)
        self.test_size = len(self.x_ids) - self.train_size

        split = self.train_size
        self.train_ids, self.train_len, self.train_label = \
            self.x_ids[:split], self.x_len[:split], self.x_label[:split]
        # Everything after the split is test data (a `[split:-1]` slice would drop the last example).
        self.test_ids, self.test_len, self.test_label = \
            self.x_ids[split:], self.x_len[split:], self.x_label[split:]

        self.idx2w = {idx: word for word, idx in self.w2idx.items()}

    def get_w2idx(self, word):
        """Return the id of ``word``, or 1 (``<unk>``) if it is not in the vocabulary."""
        return self.w2idx.get(word, 1)

    def _build_vocab(self, lines):
        """Add corpus words to ``self.w2idx`` by descending frequency until ``max_vocab`` ids exist."""
        from collections import Counter

        counts = Counter(word for line in lines for word in line.split())
        # most_common() orders equal counts by first occurrence, like a stable descending sort.
        for word, _ in counts.most_common():
            if len(self.w2idx) >= self.max_vocab:
                break
            # Never renumber a special token that happens to occur in the text.
            if word not in self.w2idx:
                self.w2idx[word] = len(self.w2idx)

    def files_to_ids(self, path):
        """Read, build the vocabulary for, shuffle and encode the plot files under ``path``.

        Side effects: sets ``self.adventure_list``, ``self.documentary_list``,
        ``self.horror_list`` and ``self.romance_list`` (sorted file names per genre)
        and fills ``self.w2idx``.

        Returns:
            ``(ids, length, label)`` numpy arrays with one entry per plot. ``ids`` rows
            hold ``max_len`` int32 ids. ``length`` is the number of real words stored
            (the end token excluded, capped at ``max_len``). ``label`` is the genre
            index % 2.
        """
        # os.listdir order is arbitrary; sorting makes file selection and the seeded
        # shuffle reproducible across machines.
        file_lists = [sorted(os.listdir(os.path.join(path, genre))) for genre in self.GENRES]
        (self.adventure_list, self.documentary_list,
         self.horror_list, self.romance_list) = file_lists

        # Take the same number of files from each genre so classes are balanced.
        size = min(len(names) for names in file_lists)
        samples = []
        for i in range(size):
            for genre_idx, (genre, names) in enumerate(zip(self.GENRES, file_lists)):
                with open(os.path.join(path, genre, names[i]), "r", encoding="utf-8") as fin:
                    # The label is attached before shuffling so it stays with its text.
                    samples.append((fin.readline(), genre_idx % 2))

        self._build_vocab(line for line, _ in samples)

        # Local RNG: same permutation as `random.seed(777); random.shuffle(...)`
        # without resetting the global random state.
        random.Random(777).shuffle(samples)

        ids, length, label = [], [], []
        for line, lab in samples:
            words = (line + " " + self.end_token).split()
            kept = words[:self.max_len]
            row = np.zeros(self.max_len, dtype=np.int32)
            row[:len(kept)] = [self.get_w2idx(word) for word in kept]
            ids.append(row)
            # Real words stored: excludes the appended end token, capped by truncation.
            length.append(min(len(words) - 1, self.max_len))
            label.append(lab)

        return np.array(ids), np.array(length), np.array(label)

    def get_train(self, batch_size=20):
        """Return the next ``(ids, length, label)`` training batch and advance the cursor.

        The cursor wraps modulo ``train_size``; the batch that reaches the end of the
        training data may be shorter than ``batch_size``.
        """
        pt = self.train_pt
        end = pt + batch_size
        self.train_pt = end % self.train_size
        return self.train_ids[pt:end], self.train_len[pt:end], self.train_label[pt:end]

    def get_test(self, batch_size=20):
        """Return the next ``(ids, length, label)`` test batch and advance the cursor.

        The cursor wraps modulo ``test_size``; the batch that reaches the end of the
        test data may be shorter than ``batch_size``.
        """
        pt = self.test_pt
        end = pt + batch_size
        self.test_pt = end % self.test_size
        return self.test_ids[pt:end], self.test_len[pt:end], self.test_label[pt:end]


In [43]:
# Build the dataset once; later cells reuse its per-genre file lists.
data = plot_data(path="./dataset/poster_txt")

In [68]:
import os
path = "./dataset/poster_txt"

# Re-read the first line of each plot file, per genre and interleaved by file index.
# Each file is read ONCE. A second `fin.readline()` would return the file's *second*
# line. Leaving `lines` empty broke the vocabulary cell below on a fresh kernel.
genre_files = [
    ("Adventure", data.adventure_list),
    ("Documentary", data.documentary_list),
    ("Horror", data.horror_list),
    ("Romance", data.romance_list),
]
size = min(len(names) for _, names in genre_files)

lines1, lines2, lines3, lines4 = [], [], [], []
per_genre_lines = [lines1, lines2, lines3, lines4]
lines = []  # all genres: Adventure, Documentary, Horror, Romance, Adventure, ...
for i in range(size):
    for (genre, names), genre_lines in zip(genre_files, per_genre_lines):
        with open(os.path.join(path, genre, names[i]), "r", encoding="utf-8") as fin:
            text = fin.readline()
        genre_lines.append(text)
        lines.append(text)

In [63]:
from collections import Counter

# Frequency-ranked vocabulary over every plot in `lines` (no max_vocab cap here).
cnt1 = Counter(word for line in lines for word in line.split())
end_token = "<eos>"
w2idx = {end_token: 0, "<unk>": 1}
# Same order as sorted(..., reverse=True): descending count, ties in first-seen order.
cnt_sort = cnt1.most_common()
for word, count in cnt_sort:
    # Keep the special tokens at ids 0/1 even if they occur in the text.
    if word not in w2idx:
        w2idx[word] = len(w2idx)
        
    

In [65]:
# Number of entries in w2idx: corpus words plus the <eos>/<unk> tokens.
len(w2idx.keys())

60119

Adventure 단어 수 : 24310
Documentary 단어 수 : 24594
Horror 단어 수 : 19456
Romance 단어 수 : 23960
총 단어 수 (w2idx 크기, <eos>/<unk> 포함) : 60119


전체에서 가장 긴 sentence의 단어 개수 : 86
Adventure 가장 긴 sentence의 단어 개수 : 61
Documentary 가장 긴 sentence의 단어 개수 : 58
Horror 가장 긴 sentence의 단어 개수 : 81
Romance 가장 긴 sentence의 단어 개수 : 86

In [75]:

# Longest Romance plot (lines4), measured in whitespace-separated words.
max_word_cnt = 0
for line in lines4:
    cnt = len(line.split())
    if cnt > max_word_cnt:
        max_word_cnt = cnt
# Display the maximum; `print(cnt)` only showed the last plot's word count.
max_word_cnt

86


In [67]:
# Longest Romance plot in words; `cnt` only held the last plot's word count.
max_word_cnt

86