In [None]:
from tensorflow.keras import preprocessing
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tensorflow.keras.layers import Input
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
import tensorflow as tf
import numpy as np
import h5py
import os
import re
import random
import pickle
# Preprocessing helpers: split image codes, extract CNN image features and encode captions

  #train,val,test별로 사진코드들을 모아놓은 리스트 생성
def sorted_img_code_list(get_path):
  """Read an image-code list file and return its lines.

  Args:
    get_path: path to a text file with one image code (file name) per line,
      e.g. one of the Flickr8k split files used in __main__.

  Returns:
    A list of the file's lines with line terminators removed, in file order.
    Note: despite the name, the lines are not sorted here.
  """
  with open(get_path, 'r') as code_file:
    contents = code_file.read()
  return contents.splitlines()


def img_code_list(train_code_path, val_code_path, test_code_path):
  """Load the image-code lists of the train, validation and test splits.

  Each path is read with sorted_img_code_list, in train/val/test order.

  Returns:
    A (train_list, val_list, test_list) tuple of lists of image codes.
  """
  split_paths = (train_code_path, val_code_path, test_code_path)
  return tuple(sorted_img_code_list(split_path) for split_path in split_paths)

def create_tokenizer(seq):
  """Fit a Keras Tokenizer on every caption contained in `seq`.

  Args:
    seq: dict whose values are lists of caption strings (as returned by
      sequence_refining).

  Returns:
    The fitted Tokenizer, with '<pad>' additionally registered as index 0
    in both word_index and index_word.
  """
  every_caption = [caption for captions in seq.values() for caption in captions]
  tokenizer = Tokenizer()
  tokenizer.fit_on_texts(every_caption)
  # Keras never assigns index 0 to a word; map it to an explicit padding
  # token so padded positions can be decoded back to text.
  tokenizer.word_index['<pad>'] = 0
  tokenizer.index_word[0] = '<pad>'
  return tokenizer

def get_max(t,seq):
  """Return the token length of the longest caption in `seq`.

  Args:
    t: fitted tokenizer exposing texts_to_sequences (see create_tokenizer).
    seq: dict whose values are lists of caption strings.

  Returns:
    The largest number of tokens any single caption encodes to; 0 when
    `seq` is empty or contains no captions.
  """
  max_len = 0
  for captions in seq.values():
    # Encode the whole caption list in one call instead of one call per line.
    encoded = t.texts_to_sequences(list(captions))
    # default=0 keeps an image with no captions from raising ValueError.
    longest = max((len(tokens) for tokens in encoded), default=0)
    if longest > max_len:
      max_len = longest
  return max_len



def save_all_seq_data(t,seq,max_len,train_list,val_list,test_list,
                      train_seq_path,val_seq_path,test_seq_path):
  """Write the encoded caption sequences of each split to its own HDF5 file.

  Calls save_seq_data for the train, validation and test splits, in that
  order, with the same tokenizer `t`, caption dict `seq` and `max_len`.
  """
  split_jobs = (
      (train_list, train_seq_path),
      (val_list, val_seq_path),
      (test_list, test_seq_path),
  )
  for code_list, out_path in split_jobs:
    save_seq_data(t, seq, code_list, max_len, out_path)


def save_seq_data(t,seq,sorted_list,max_len,saved_seq_path):
  """Tokenize and pad each image's captions and write them to an HDF5 file.

  Args:
    t: fitted tokenizer (see create_tokenizer).
    seq: dict mapping image code -> list of caption strings.
    sorted_list: image codes to write; each must be a key of `seq`.
    max_len: length every encoded caption is padded to; zeros are
      appended at the end ('post' padding).
    saved_seq_path: output HDF5 path; an existing file is overwritten.

  One dataset per image code is written, named by the code and holding the
  padded token-id array returned by pad_sequences for that image's captions.
  """
  with h5py.File(saved_seq_path, 'w') as out_file:
    for img_code in sorted_list:
      encoded = t.texts_to_sequences(seq[img_code])
      padded = pad_sequences(encoded, maxlen=max_len, padding='post')
      out_file.create_dataset(img_code, data=padded)

  # Clean the caption text and build a dict keyed by image file code.
def sequence_refining(token_path):
  """Clean caption text and group the captions by image file.

  Each non-blank line of `token_path` is parsed as
  ``<image file>#<caption index>\\t<caption>``. The program removes the
  punctuation matched by the character class below. Each caption is then
  wrapped as ``"sq " + caption + " eq"``.

  Args:
    token_path: path to the caption file (e.g. Flickr8k.token.txt).

  Returns:
    dict mapping image file name -> list of cleaned captions, in file order.

  Raises:
    ValueError: if a non-blank line contains no tab separator.
  """
  # Raw string: the old non-raw literal depended on invalid escape sequences
  # such as '\?' and '\(' (SyntaxWarning on Python 3.12+). Same character set.
  # Compiled once instead of once per line.
  punctuation = re.compile(r'[-=+,#/\?:^$.@*"※~&%ㆍ!』\‘|\(\)\[\]\<\>`\'…》]')
  dic = {}
  with open(token_path, 'r', encoding='utf-8') as f:
    for line in f.read().splitlines():
      if not line.strip():
        # Skip blank lines instead of failing on the tab split below.
        continue
      # maxsplit=1: a stray tab inside the caption stays part of the caption.
      filename, disc = line.split('\t', 1)
      # Drop the '#<n>' caption index to get the image file name.
      rfilename = filename.partition('#')[0]
      disc = punctuation.sub('', disc)
      dic.setdefault(rfilename, []).append("sq " + disc + " eq")
  return dic

#image에서 사전 훈련된 CNN(inception_V3)를 이용해 특징을 뽑은 뒤 h5 포맷으로 저장 
def save_img_feature(sorted_save_path,img_data_directory,CNN_Model,img_list):
  """Extract CNN features for each image and store them in an HDF5 file.

  Args:
    sorted_save_path: output HDF5 path; an existing file is overwritten.
    img_data_directory: directory containing the image files.
    CNN_Model: Keras model applied to each preprocessed 299x299 image. The
      reshape below reads shape[3], so its output is expected to be a 4-D
      (batch, height, width, channels) map. In __main__ it is InceptionV3
      truncated at layers[-3].
    img_list: image file names relative to img_data_directory. Each name is
      also used as the name of its dataset.

  Each dataset holds the feature map with its spatial axes flattened to
  (batch, height * width, channels), e.g. (1, 8, 8, 2048) -> (1, 64, 2048).
  Progress is printed every 100 images.
  """
  # The context manager closes the file even if loading or prediction fails
  # part-way, instead of leaving the HDF5 handle open.
  with h5py.File(sorted_save_path, 'w') as h5:
    for indx, img in enumerate(img_list):
      img_path = os.path.join(img_data_directory, img)
      loaded_img = image.load_img(img_path, target_size = (299, 299))
      loaded_img = image.img_to_array(loaded_img)
      loaded_img = preprocess_input(loaded_img)
      # Add the batch axis expected by predict().
      loaded_img = np.expand_dims(loaded_img, 0)
      feature = CNN_Model.predict(loaded_img)
      feature = tf.reshape(feature, (feature.shape[0],-1,feature.shape[3]))
      if indx % 100 == 0:
        print('processing'+str(indx))
      h5.create_dataset(img, data=feature)
    print("complet")

def save_all_img_feature(img_data_directory,CNN_model,
                         train_list,val_list,test_list,
                         train_feature_path,val_feature_path,test_feature_path):
  """Extract and save image features for the train, validation and test splits.

  Calls save_img_feature once per split, in train/val/test order, with the
  same image directory and CNN model.
  """
  feature_jobs = [
      (train_feature_path, train_list),
      (val_feature_path, val_list),
      (test_feature_path, test_list),
  ]
  for feature_path, code_list in feature_jobs:
    save_img_feature(feature_path, img_data_directory, CNN_model, code_list)


if __name__ == "__main__":
  # Feature extractor: InceptionV3 truncated at layers[-3]. save_img_feature
  # reshapes its output using shape[3], so this layer must produce a 4-D map.
  model = InceptionV3()
  CNN_Model = Model(inputs=model.inputs, outputs = model.layers[-3].output)

  # Root directory where the preprocessed outputs are saved.
  base_directory = '/content/gdrive/My Drive/Colab Notebooks/image_captioning_with_attention'

  # Root directory of the raw dataset.
  data_base_directory ='/content'

  # Directory holding the raw image files.
  img_data_directory = os.path.join(data_base_directory,'flickr8k_dataset','Flicker8k_Dataset')

  # Every artifact below is written under this directory. Create it first so
  # h5py.File / open do not fail when it does not exist yet.
  saved_data_path = os.path.join(base_directory,'data')
  os.makedirs(saved_data_path, exist_ok=True)

  # Output paths for the extracted image features.
  train_feature_path = os.path.join(saved_data_path, 'train_features.hdf5')
  val_feature_path = os.path.join(saved_data_path, 'val_features.hdf5')
  test_feature_path = os.path.join(saved_data_path, 'test_features.hdf5')

  # Output paths for the encoded caption sequences.
  train_seq_path = os.path.join(saved_data_path,'train_sequence.hdf5')
  val_seq_path = os.path.join(saved_data_path,'val_sequence.hdf5')
  test_seq_path = os.path.join(saved_data_path,'test_sequence.hdf5')

  # Text files listing the image codes of each split.
  text_path = os.path.join(data_base_directory,'flickr8k_text')
  train_code_path = os.path.join(text_path,'Flickr_8k.trainImages.txt')
  val_code_path = os.path.join(text_path,'Flickr_8k.devImages.txt')
  test_code_path = os.path.join(text_path,'Flickr_8k.testImages.txt')

  # Caption file: "<image file>#<n>\t<caption>" lines, parsed by sequence_refining.
  token_path = os.path.join(text_path,'Flickr8k.token.txt')

  # Load the image codes of each split.
  train_list,val_list,test_list = img_code_list(train_code_path, val_code_path, test_code_path)

  # Extract and save CNN features for every split.
  save_all_img_feature(img_data_directory,CNN_Model,
                       train_list,val_list,test_list,
                       train_feature_path,val_feature_path,test_feature_path)

  # Clean the captions and group them by image file.
  sequence = sequence_refining(token_path)

  # Fit the tokenizer on all captions.
  tokenizer = create_tokenizer(sequence)

  # Longest encoded caption; every saved sequence is padded to this length.
  max_len = get_max(tokenizer,sequence)

  # Save the padded caption sequences of every split.
  save_all_seq_data(tokenizer,sequence,max_len,
                    train_list,val_list,test_list,
                    train_seq_path,val_seq_path,test_seq_path)

  # Pickle the Python objects that later stages reload.
  req_token_path = os.path.join(saved_data_path,'token.pickle')
  req_seq_path = os.path.join(saved_data_path,'sequence.pickle')
  req_train_list_path = os.path.join(saved_data_path,'train_list.pickle')
  req_val_list_path = os.path.join(saved_data_path,'val_list.pickle')
  req_test_list_path = os.path.join(saved_data_path,'test_list.pickle')
  for pickled_obj, pickle_path in ((tokenizer, req_token_path),
                                   (sequence, req_seq_path),
                                   (train_list, req_train_list_path),
                                   (val_list, req_val_list_path),
                                   (test_list, req_test_list_path)):
    with open(pickle_path, 'wb') as handle:
      pickle.dump(pickled_obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

  # Record every path and size needed downstream in one HDF5 file. The
  # with-block closes the file even if a write fails part-way.
  data_h5_paths = os.path.join(saved_data_path, 'needs.hdf5')
  with h5py.File(data_h5_paths,'w') as needs_h5:
    for ds_name, ds_value in (('token_path', token_path),
                              ('base_directory', base_directory),
                              ('train_code_path', train_code_path),
                              ('val_code_path', val_code_path),
                              ('test_code_path', test_code_path),
                              ('train_feature_path', train_feature_path),
                              ('val_feature_path', val_feature_path),
                              ('test_feature_path', test_feature_path),
                              ('train_seq_path', train_seq_path),
                              ('val_seq_path', val_seq_path),
                              ('test_seq_path', test_seq_path)):
      needs_h5.create_dataset(ds_name, data=ds_value)

    needs_h5.create_dataset('max_len', data = max_len,dtype='int')
    # NOTE(review): word_index already contains '<pad>' (added by
    # create_tokenizer), so this +1 makes the value one larger than the
    # number of indices in use. Left unchanged because downstream stages may
    # depend on it.
    needs_h5.create_dataset('vocab_size', data = len(tokenizer.word_index)+1,dtype ='int')
    for ds_name, ds_value in (('req_token_path', req_token_path),
                              ('req_seq_path', req_seq_path),
                              ('req_train_list_path', req_train_list_path),
                              ('req_val_list_path', req_val_list_path),
                              ('req_test_list_path', req_test_list_path)):
      needs_h5.create_dataset(ds_name, data=ds_value)

    print('end')


processing0
processing100
processing200
processing300
processing400
processing500
processing600
processing700
processing800
processing900
processing1000
processing1100
processing1200
processing1300
processing1400
processing1500
processing1600
processing1700
processing1800
processing1900
processing2000
processing2100
processing2200
processing2300
processing2400
processing2500
processing2600
processing2700
processing2800
processing2900
processing3000
processing3100
processing3200
processing3300
processing3400
processing3500
processing3600
processing3700
processing3800
processing3900
processing4000
processing4100
processing4200
processing4300
processing4400
processing4500
processing4600
processing4700
processing4800
processing4900
processing5000
processing5100
processing5200
processing5300
processing5400
processing5500
processing5600
processing5700
processing5800
processing5900
complet
processing0
processing100
processing200
processing300
processing400
processing500
processing600
processi