In [13]:
import numpy as np
from scipy import misc
import scipy
import os
import pylab as plt
import pandas as pd
import random
from keras.utils import to_categorical

class movie_data(object):
    """Loader for a movie-poster / plot-text / genre-label dataset.

    On construction the loader:
      * reads the label CSV and builds the genre <-> label-id dictionaries,
      * lists the poster images and splits them 80/20 into train and test
        with a fixed random seed,
      * loads and min-max scales every poster,
      * builds a word vocabulary from the first line of each plot file and
        encodes that line as a fixed-length array of word ids.

    NOTE(review): image I/O goes through ``scipy.misc.imread`` and
    ``scipy.misc.imresize``. The notebook output shows both are deprecated in
    SciPy 1.0.0 and scheduled for removal in 1.2.0, so this class needs an old
    SciPy (with Pillow) until they are replaced.
    """

    def __init__(self, read_type='b', max_vocab=30000, max_len=500, end_token="<eos>"):
        """Load the whole dataset into memory.

        Args:
            read_type: Unused; kept so existing callers keep working.
            max_vocab: Maximum number of entries in the word -> id vocabulary
                (including the two reserved tokens).
            max_len: Number of word ids kept per plot line.
            end_token: Token appended to every plot line; it gets id 0.
        """
        self.train_pt, self.test_pt = 0, 0
        self.img_path = "../dataset/poster_image/new/"
        self.label_path = "../dataset/MovieGenre_fix_leekeon_total6000.csv"
        self.plot_path = "../dataset/poster_txt/new/"
        self.max_len = max_len
        self.max_vocab = max_vocab
        self.end_token = end_token

        self.genre_2_labelid = {}
        self.movieid_2_genre = {}
        self.labelid_2_genre = {}

        self.movie_df = pd.read_csv(self.label_path, encoding="ISO-8859-1")

        self.dict_init()

        # os.listdir returns entries in arbitrary order; sort first so the
        # seeded shuffle below yields the same split on every machine.
        files = sorted(file for file in os.listdir(self.img_path)
                       if os.path.isfile(os.path.join(self.img_path, file)))

        random.seed(777)
        random.shuffle(files)

        self.train_size = int(len(files) * 0.8)
        self.test_size = len(files) - self.train_size

        train_files = files[:self.train_size]
        test_files = files[self.train_size:]

        # Posters and their result labels.
        self.movieid_train, self.x_img_train, self.y_train = self.init_data(train_files)
        self.movieid_test, self.x_img_test, self.y_test = self.init_data(test_files)

        # Plot text: word preprocessing. The vocabulary is built from all files.
        self.w2idx = {end_token: 0, "<unk>": 1}
        self.files_to_word(files)
        self.x_ids_train, self.x_len_train = self.init_plot_data(train_files)
        # The test split must be encoded from the test files (not train_files).
        self.x_ids_test, self.x_len_test = self.init_plot_data(test_files)

        self.data_summary()

    def data_summary(self):
        """Print the number of loaded train/test images and labels."""
        print('x_train size', len(self.x_img_train))
        print('y_train size', len(self.y_train))
        print('x_test  size', len(self.x_img_test))
        print('y_test  size', len(self.y_test))

    def preprocess(self, img, size=(134, 91)):
        """Resize ``img`` to ``size`` and map pixel values via ``v / 127.5 - 1``.

        NOTE(review): relies on the deprecated ``scipy.misc.imresize``.
        """
        img = scipy.misc.imresize(img, size)
        img = img.astype(np.float32)
        img = (img / 127.5) - 1.

        return img

    def min_max_scaling(self, x, size=(134, 91)):
        """Resize ``x`` to ``size`` and min-max scale it.

        The small epsilon in the denominator avoids division by zero for a
        constant image.

        NOTE(review): relies on the deprecated ``scipy.misc.imresize``.
        """
        x = scipy.misc.imresize(x, size)
        x_np = np.asarray(x)

        return (x_np - x_np.min()) / (x_np.max() - x_np.min() + 1e-7)

    @staticmethod
    def reverse_min_max_scaling(org_x, x):
        """Undo ``min_max_scaling`` using the value range of ``org_x``.

        Declared as a static method because it takes no ``self``; without the
        decorator, calling it on an instance raised ``TypeError``. Calling it
        on the class keeps working as before.

        Args:
            org_x: Array-like holding the original (unscaled) values.
            x: Array-like of scaled values to map back.

        Returns:
            ``x * (max(org_x) - min(org_x) + 1e-7) + min(org_x)`` as an array.
        """
        org_x_np = np.asarray(org_x)
        x_np = np.asarray(x)
        return (x_np * (org_x_np.max() - org_x_np.min() + 1e-7)) + org_x_np.min()

    def dict_init(self):
        """Fill the movie-id -> genre and genre <-> label-id dictionaries.

        Reads column 0 and column 7 of every CSV row (presumably the movie id
        and the genre name; verify against the CSV header). Label ids are
        assigned in order of first appearance.
        """
        for movie in self.movie_df.values:
            movie_id, genre = movie[0], movie[7]
            self.movieid_2_genre[movie_id] = genre

            if genre not in self.genre_2_labelid:
                label_id = len(self.genre_2_labelid)
                self.genre_2_labelid[genre] = label_id
                # Record the reverse mapping here so it does not depend on
                # dict iteration order.
                self.labelid_2_genre[label_id] = genre

    def get_train(self, batch_size=20):
        """Return the next ``(movie ids, images, labels)`` training batch.

        The read pointer advances by ``batch_size`` modulo the train size, so
        the last batch before the wrap may hold fewer than ``batch_size`` items.
        """
        pt = self.train_pt
        self.train_pt = (self.train_pt + batch_size) % self.train_size
        return self.movieid_train[pt: pt+batch_size], self.x_img_train[pt: pt+batch_size], self.y_train[pt: pt+batch_size]

    def get_test(self, batch_size=20):
        """Return the next ``(movie ids, images, labels)`` test batch.

        The read pointer advances by ``batch_size`` modulo the test size, so
        the last batch before the wrap may hold fewer than ``batch_size`` items.
        """
        pt = self.test_pt
        self.test_pt = (self.test_pt + batch_size) % self.test_size
        return self.movieid_test[pt: pt+batch_size], self.x_img_test[pt: pt+batch_size], self.y_test[pt: pt+batch_size]

    def get_train_dataset(self, num_classes=4):
        """Return all training images and their one-hot encoded labels.

        Args:
            num_classes: Width of the one-hot encoding (default 4, as before).
        """
        return self.x_img_train, to_categorical(self.y_train, num_classes)

    def get_test_dataset(self, num_classes=4):
        """Return all test images and their one-hot encoded labels.

        Args:
            num_classes: Width of the one-hot encoding (default 4, as before).
        """
        return self.x_img_test, to_categorical(self.y_test, num_classes)

    def to_rgb2(self, im):
        """Expand a 2-D grayscale array to three identical channels.

        The input dtype is preserved. ``init_data`` calls this after min-max
        scaling, and forcing ``uint8`` there truncated the scaled float values
        to zero.
        """
        w, h = im.shape
        ret = np.empty((w, h, 3), dtype=im.dtype)
        # Broadcast the single channel across the new last axis.
        ret[:, :, :] = im[:, :, np.newaxis]
        return ret

    def init_data(self, files):
        """Load poster images and look up their labels.

        Args:
            files: Image file names inside ``self.img_path``; the part before
                the first ``.`` must be the integer movie id.

        Returns:
            Three parallel lists: movie ids, scaled images, and label ids.

        Raises:
            KeyError: If a movie id is missing from the label CSV.
        """
        movieid, x, y = [], [], []
        for file in files:
            fid = int(file.split('.')[0])
            # Resolve the label first so an unknown id fails before image I/O.
            lid = self.genre_2_labelid[self.movieid_2_genre[fid]]

            # NOTE(review): scipy.misc.imread is deprecated (see class docstring).
            img = self.min_max_scaling(misc.imread(os.path.join(self.img_path, file)))
            if img.ndim == 2:
                img = self.to_rgb2(img)
            # Keep at most the first three channels (drops e.g. an alpha channel).
            img = img[..., :3]

            movieid.append(fid)
            x.append(img)
            y.append(lid)
        return movieid, x, y

    def get_w2idx(self, word):
        """Return the id of ``word``, or 1 (the ``<unk>`` id) if it is unknown."""
        return self.w2idx.get(word, 1)

    def _read_first_lines(self, files):
        """Return the first line of the ``.txt`` plot file matching each name in ``files``."""
        lines = []
        for file in files:
            filename = file.split('.')[0] + '.txt'
            with open(os.path.join(self.plot_path, filename), "r", encoding="utf-8") as fin:
                lines.append(fin.readline())
        return lines

    def files_to_word(self, files):
        """Extend ``self.w2idx`` with the most frequent words of the plot files.

        Words are ranked by descending frequency, with ties kept in order of
        first appearance. They are added until the vocabulary holds
        ``self.max_vocab`` entries. Words already in the vocabulary, such as
        the reserved tokens, keep their existing id.
        """
        cnt = {}
        for line in self._read_first_lines(files):
            for word in line.split():
                cnt[word] = cnt.get(word, 0) + 1

        ranked = sorted(cnt.items(), key=lambda item: item[1], reverse=True)
        for word, _ in ranked:
            # Compare the vocabulary *size* with the cap; comparing the dict
            # itself with an int is never true, so the cap never applied.
            if len(self.w2idx) >= self.max_vocab:
                break
            if word in self.w2idx:
                continue
            self.w2idx[word] = len(self.w2idx)

    def init_plot_data(self, files):
        """Encode the first line of each plot file as a fixed-length id array.

        Each line gets ``self.end_token`` appended, is split on whitespace and
        truncated to ``self.max_len`` tokens. Unknown words are added to the
        vocabulary while it has room, otherwise they map to the ``<unk>`` id.
        Unused trailing positions stay 0.

        Returns:
            ``(ids, lengths)``: ``ids`` is an int32 array with one row of
            ``max_len`` ids per file. ``lengths`` holds the number of tokens
            written to each row, end token included, at most ``max_len``.
        """
        length, ids = [], []
        for line in self._read_first_lines(files):
            row = np.zeros(self.max_len, dtype=np.int32)
            words = (line + " " + self.end_token).split()[:self.max_len]
            for i, word in enumerate(words):
                if word not in self.w2idx and len(self.w2idx) < self.max_vocab:
                    self.w2idx[word] = len(self.w2idx)
                row[i] = self.get_w2idx(word)
            ids.append(row)
            # Count the tokens actually written; the leaked loop index was one
            # short for untruncated lines.
            length.append(len(words))
        return np.array(ids), np.array(length)


In [14]:
# Build the dataset with default settings: the constructor reads the label CSV,
# loads the poster images and plot text, and prints the train/test sizes.
data = movie_data()

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


x_train size 4800
y_train size 4800
x_test  size 1200
y_test  size 1200


[189142,
 104812,
 1854364,
 2256703,
 428251,
 1003080,
 1637612,
 71634,
 1859621,
 91083,
 1772925,
 3548962,
 1592527,
 2883352,
 3021058,
 3062074,
 379060,
 288477,
 59260,
 112281,
 88208,
 2752724,
 2316801,
 102522,
 1344811,
 371576,
 93468,
 113932,
 1832368,
 1670995,
 2072214,
 94320,
 291502,
 1772407,
 4786282,
 290326,
 370244,
 79450,
 31058,
 96754,
 1951265,
 24894,
 382189,
 100633,
 97737,
 444759,
 80764,
 790712,
 3093520,
 82379,
 86050,
 88771,
 1365474,
 4092686,
 387037,
 1550557,
 204640,
 805570,
 2417154,
 2104994,
 976026,
 81633,
 4026600,
 338977,
 3021354,
 96874,
 94701,
 2075340,
 3253624,
 106223,
 112515,
 1985019,
 78960,
 58006,
 436445,
 299988,
 338427,
 2294916,
 64639,
 430674,
 56172,
 4065414,
 326820,
 1433819,
 2505856,
 411542,
 141861,
 261983,
 218822,
 2852470,
 1482393,
 1273222,
 2992302,
 2043814,
 3626742,
 2387433,
 122541,
 5157030,
 2976354,
 484175,
 50327,
 448993,
 3573598,
 1322306,
 4131208,
 136244,
 68421,
 62622,
 16306