<a href="https://colab.research.google.com/github/LeehyeongTea/image_captioning_with_attention/blob/main/model_merge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import h5py
import os
import tensorflow as tf
from tensorflow.keras.layers import Dense, LSTM, Input, Embedding, Dropout,BatchNormalization,Lambda, Add,Flatten,GRU
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.models import Model

from tensorflow.keras import regularizers,optimizers,losses
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

from tensorflow.keras.utils import plot_model,to_categorical
import pickle


def get_path(base_directory):
  """Read dataset locations and metadata from ``{base_directory}/data/needs.hdf5``.

  Returns an 11-tuple, in this order:
    (train_dataset_list_path, val_dataset_list_path, train_feature_path,
     val_feature_path, train_seq_path, val_seq_path, max_len, vocab_size,
     req_train_list_path, req_val_list_path, req_token_path)
  read from the datasets 'train_code_path', 'val_code_path',
  'train_feature_path', 'val_feature_path', 'train_seq_path', 'val_seq_path',
  'max_len', 'vocab_size', 'req_train_list_path', 'req_val_list_path' and
  'req_token_path'.

  NOTE(review): with h5py >= 3 string datasets read via ``[()]`` come back as
  ``bytes``; callers pass them straight to ``open``/``h5py.File``, which accept
  bytes paths — confirm if these values are ever used as ``str``.
  """
  saved_data_path = os.path.join(base_directory, 'data')
  data_h5_paths = os.path.join(saved_data_path, 'needs.hdf5')
  keys = ('train_code_path', 'val_code_path',
          'train_feature_path', 'val_feature_path',
          'train_seq_path', 'val_seq_path',
          'max_len', 'vocab_size',
          'req_train_list_path', 'req_val_list_path', 'req_token_path')
  # `[()]` copies each dataset's value into memory, so the file can (and
  # should) be closed before returning instead of leaking the handle.
  with h5py.File(data_h5_paths, 'r') as needs:
    return tuple(needs[key][()] for key in keys)


def get_data(train_feature_path,train_seq_path, val_feature_path,val_seq_path):
  """Open the four HDF5 stores read-only and return their handles.

  Returned order: (train_feature, train_seq, val_feature, val_seq).
  The caller owns the handles and is responsible for closing them.
  """
  ordered_paths = (train_feature_path, train_seq_path,
                   val_feature_path, val_seq_path)
  return tuple(h5py.File(path, 'r') for path in ordered_paths)
"""  
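Q = Query : the decoder's hidden state from time step t-1 -> `hidden`
K = Keys : the encoder outputs for every position (here: every image region) => img feature (the same tensor as V)
V = Values : the encoder outputs for every position (here: every image region) => img feature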
"""
# Bahdanau (additive) attention over image regions.
class Attention(Model):
  """Additive attention: score = V(tanh(W1(values) + W2(query))).

  The query is the decoder's previous hidden state. The same tensor
  (`values`, the encoded image features) serves as both keys and values.
  """

  def __init__(self, units, embedding_size, num_pixels=64):
    """
    Args:
      units: width of the W1/W2 projections and of the query vector.
      embedding_size: feature size of each image region in `values`.
      num_pixels: number of image regions in the symbolic input created by
        `build()`; 64 by default (the value previously hard-coded). Pass
        None to accept any number of regions.
    """
    super(Attention, self).__init__()
    self.W1 = Dense(units)
    self.W2 = Dense(units)
    self.V = Dense(1)
    self.embedding_size = embedding_size
    self.units = units
    self.num_pixels = num_pixels

  def call(self, values, query):
    """Return (context_vector, attention_dist).

    Args:
      values: image features, (batch, num_pixels, embedding_size) as declared
        in `build()`.
      query: previous decoder hidden state, (batch, units) as declared in
        `build()`.
    Returns:
      context_vector: attention-weighted sum of `values`,
        (batch, embedding_size).
      attention_dist: softmax weights over regions, (batch, num_pixels, 1).
    """
    # (batch, units) -> (batch, 1, units) so it broadcasts over every region.
    query = tf.expand_dims(query, axis=1)

    # One unnormalised score per region: (batch, num_pixels, 1).
    score = self.V(
        tf.nn.tanh(self.W1(values) + self.W2(query)))

    # Normalise across the region axis.
    attention_dist = tf.nn.softmax(score, axis=1)
    context_vector = tf.reduce_sum(attention_dist * values, axis=1)

    return context_vector, attention_dist

  def build(self):
    """Return a functional Model mapping [values, query] -> call(values, query).

    NOTE(review): this shadows keras `Model.build(input_shape)` with an
    incompatible signature. It works here because the caller in this file
    invokes `.build()` explicitly and uses the returned Model. Calling an
    Attention instance directly would likely make Keras call
    `build(input_shape)` and fail.
    """
    values = Input(shape=(self.num_pixels, self.embedding_size))
    query = Input(shape=(self.units,))

    return Model(inputs=[values, query], outputs=self.call(values, query))

class Encoder(Model):
  """Project pre-extracted CNN region features into the embedding space.

  A single ReLU Dense layer with an L2-regularised kernel, applied along the
  last axis: (batch, num_pixels, feature_dim) -> (batch, num_pixels,
  embedding_dim).
  """

  def __init__(self, embedding_dim, reg, num_pixels=64, feature_dim=2048):
    """
    Args:
      embedding_dim: output size of the projection.
      reg: L2 factor for the Dense kernel regulariser.
      num_pixels: number of image regions in the symbolic input created by
        `build()`; 64 by default (the value previously hard-coded). Pass
        None to accept any number of regions.
      feature_dim: channel size of the pre-extracted features; 2048 by
        default (the value previously hard-coded). Change it when features
        come from a different backbone.
    """
    super(Encoder,self).__init__()

    self.fc = Dense(embedding_dim, activation = 'relu',kernel_regularizer = regularizers.l2(reg))
    self.num_pixels = num_pixels
    self.feature_dim = feature_dim

  def call(self, input):
    """Apply the projection; `input` is (batch, num_pixels, feature_dim)."""
    img_feature = self.fc(input)

    return img_feature

  def build(self):
    """Return a functional Model wrapping `call` on a (num_pixels, feature_dim) Input.

    NOTE(review): this shadows keras `Model.build(input_shape)` with an
    incompatible signature; callers are expected to use the returned Model.
    """
    x = Input(shape=(self.num_pixels, self.feature_dim))

    return Model(inputs=x,outputs=self.call(x))

class Decoder(Model):
  """One decoding step of the merge captioning model with Bahdanau attention.

  Each step:
    1. attends over the image features, using `hidden` as the query;
    2. embeds the current token and runs it through an LSTM;
    3. adds the ReLU-projected context vector to the LSTM output;
    4. maps the sum to a softmax over the vocabulary.
  """

  def __init__(self, max_len, embedding_size, units, vocab_size, reg):
    """
    Args:
      max_len: not used by the decoder; kept for call-site compatibility.
      embedding_size: token-embedding size; also the per-region image feature
        size expected by the attention.
      units: LSTM width, attention width and size of the `hidden` input.
      vocab_size: vocabulary size for the embedding and the output layer.
      reg: L2 factor for the fc1/fc2 kernel regularisers.
    """
    super(Decoder,self).__init__()
    self.units = units
    self.embedding_size = embedding_size
    self.embedding_layer = Embedding(vocab_size, embedding_size)
    self.lstm = LSTM(units)

    self.attention = Attention(units,embedding_size).build()
    self.fc1 = Dense(self.units, activation = 'relu', kernel_regularizer = regularizers.l2(reg))
    # Softmax output: the decoder emits probabilities, not logits.
    self.fc2 = Dense(vocab_size, activation = 'softmax', kernel_regularizer = regularizers.l2(reg))

  def call(self, sequence,img, hidden):
    """Return (probs, output, attention_dist) for one step.

    Args:
      sequence: current token ids, (batch, 1).
      img: encoded image features, (batch, 64, embedding_size).
      hidden: previous step's `output`, used as the attention query,
        (batch, units).
    Returns:
      probs: softmax over the vocabulary, (batch, vocab_size).
      output: LSTM output, (batch, units); the training loop feeds it back
        as the next step's `hidden`.
      attention_dist: attention weights over the image regions.
    """
    context_vector, attention_dist = self.attention([img, hidden])
    sequence = self.embedding_layer(sequence)

    # NOTE(review): no initial_state is passed, so the LSTM starts from zeros
    # every step and `hidden` reaches the model only through the attention
    # query. If true recurrence is intended, feed the previous state via
    # `initial_state`. Confirm before changing: it alters the trained model.
    output = self.lstm(sequence)

    x = self.fc1(context_vector)

    merge = Add()([output, x])
    merge = self.fc2(merge)

    return merge,output,attention_dist

  def build(self):
    """Return a functional Model over [sequence, img, hidden] Inputs.

    NOTE(review): this shadows keras `Model.build(input_shape)` with an
    incompatible signature; callers are expected to use the returned Model.
    """
    sequence = Input(shape = (1,))
    img = Input(shape = (64, self.embedding_size))
    # Trailing comma: `(self.units)` is a bare int, not a shape tuple.
    hidden = Input(shape = (self.units,))
    return Model(inputs=[sequence, img, hidden],outputs = self.call(sequence,img,hidden))


# Called per decoding step as get_loss_acc(text[:, i], output, compile_loss, vocab_size).
def get_loss_acc(input, target,compile_loss,vocab_size):
  """Return (padding-masked mean loss, token accuracy) for one decoding step.

  Args:
    input: ground-truth token ids for this step. Despite the name these are
      the labels; id 0 is treated as padding.
    target: the decoder's predicted distribution over the vocabulary,
      (batch, vocab_size). Despite the name this is the model output.
    compile_loss: per-example loss (reduction='none'), called as
      compile_loss(y_true, y_pred).
    vocab_size: no longer needed; kept for call-site compatibility.
  Returns:
    (loss, accuracy) scalar tensors.
  """
  # compile_loss uses reduction='none', so reduce_mean is applied below.
  # Zero out the loss at padding positions (id 0).
  mask = tf.math.logical_not(tf.math.equal(input, 0))
  loss = compile_loss(input, target)
  loss *= tf.cast(mask, dtype=loss.dtype)

  # argmax(one_hot(ids)) == ids, so compare the ids directly instead of
  # materialising a (batch, vocab_size) one-hot tensor every step.
  predicted = tf.argmax(target, 1)
  correct = tf.equal(tf.cast(input, predicted.dtype), predicted)
  # NOTE(review): unlike the loss, accuracy is not masked, so padding
  # positions are counted (and, never being trained, are almost always wrong).

  return tf.reduce_mean(loss), tf.reduce_mean(tf.cast(correct, tf.float32))


@tf.function
def train_step(encoder,decoder,img, text,
               tokenizer,optimizer,compile_loss):
  """Run one teacher-forced optimisation step and return (accuracy, loss).

  Args:
    encoder: functional Model from Encoder.build().
    decoder: functional Model from Decoder.build(). Its third input
      (`hidden`) sizes the zero initial attention query.
    img: batch of pre-extracted image features.
    text: (batch, seq_len) token ids. The first step is fed the 'sq' token;
      positions 1..seq_len-1 are predicted with teacher forcing.
      (Presumably column 0 holds the start token; confirm.)
    tokenizer: must contain 'sq' in `word_index`.
    optimizer: optimizer whose update is applied to both models.
    compile_loss: per-token loss with reduction='none'.
  Returns:
    (accuracy, loss) averaged over the seq_len - 1 predicted positions.
    The reported loss excludes the L2 penalty, which is added only to the
    gradient objective.

  Relies on the script-level global `vocab_size`.
  """
  batch = text.shape[0]
  num_steps = int(text.shape[1]) - 1
  variables = decoder.trainable_variables + encoder.trainable_variables

  # Zero initial query sized from the decoder's `hidden` input rather than a
  # hard-coded width.
  hidden = tf.zeros((batch, decoder.inputs[2].shape[-1]))
  text_input = tf.expand_dims([tokenizer.word_index['sq']] * batch, 1)

  loss = 0
  accuracy = 0
  with tf.GradientTape() as tape:
    img = encoder(img)
    for i in range(1, int(text.shape[1])):
      output, hidden, _ = decoder([text_input, img, hidden])
      g_loss, g_acc = get_loss_acc(text[:, i], output, compile_loss, vocab_size)
      loss += g_loss
      accuracy += g_acc
      # Teacher forcing: the next input is the ground-truth token.
      text_input = tf.expand_dims(text[:, i], 1)

    # kernel_regularizer penalties are applied automatically only by fit();
    # a custom loop has to add model.losses itself (inside the tape).
    objective = loss
    reg_losses = encoder.losses + decoder.losses
    if reg_losses:
      objective += tf.add_n(reg_losses)

  gradients = tape.gradient(objective, variables)
  optimizer.apply_gradients(zip(gradients, variables))

  # Average over the steps actually run (seq_len - 1), not seq_len.
  return accuracy / num_steps, loss / num_steps

@tf.function
def test_step(encoder,decoder,img, text,
               tokenizer,optimizer,compile_loss):
  """Evaluate one batch with teacher forcing, without updating weights.

  Runs the same decoding loop as train_step but records no gradients.
  `optimizer` is unused and accepted only so both step functions share one
  signature.

  Args:
    encoder: functional Model from Encoder.build().
    decoder: functional Model from Decoder.build(). Its third input
      (`hidden`) sizes the zero initial attention query.
    img: batch of pre-extracted image features.
    text: (batch, seq_len) token ids; positions 1..seq_len-1 are scored.
    tokenizer: must contain 'sq' in `word_index`.
    optimizer: unused.
    compile_loss: per-token loss with reduction='none'.
  Returns:
    (accuracy, loss) averaged over the seq_len - 1 predicted positions.

  Relies on the script-level global `vocab_size`.
  """
  batch = text.shape[0]
  num_steps = int(text.shape[1]) - 1

  hidden = tf.zeros((batch, decoder.inputs[2].shape[-1]))
  text_input = tf.expand_dims([tokenizer.word_index['sq']] * batch, 1)

  # No GradientTape: nothing is differentiated during evaluation.
  img = encoder(img)
  loss = 0
  accuracy = 0
  for i in range(1, int(text.shape[1])):
    output, hidden, _ = decoder([text_input, img, hidden])
    g_loss, g_acc = get_loss_acc(text[:, i], output, compile_loss, vocab_size)
    loss += g_loss
    accuracy += g_acc
    text_input = tf.expand_dims(text[:, i], 1)

  # Average over the steps actually run (seq_len - 1), not seq_len.
  return accuracy / num_steps, loss / num_steps



def get_feature_x_y(features,seq,elem):
  """Return ``(feature, captions)`` for the image id ``elem``.

  Both lookups are sliced with ``[:]``: this reads an h5py dataset into
  memory, and for numpy arrays or lists it yields a view or copy.
  """
  return features[elem][:], seq[elem][:]


def get_saved_data(data_list, feature, seq):
  """Expand image ids into aligned (feature, caption) example arrays.

  Each caption stored for an image becomes one example, paired with a copy
  of that image's feature.

  Args:
    data_list: image ids to load.
    feature: mapping id -> feature array (an open h5py File in this script).
    seq: mapping id -> captions, indexable as one entry per caption.
  Returns:
    (features, captions) numpy arrays with one entry per caption along
    axis 0. Size-1 axes of the stacked features are squeezed away, except
    axis 0, so a batch holding exactly one caption keeps shape (1, ...).
  """
  features = []
  captions = []
  for elem in data_list:
    f, t = get_feature_x_y(feature, seq, elem)
    features.extend([f] * len(t))
    captions.extend(t[i] for i in range(len(t)))

  stacked = np.array(features)
  # A bare .squeeze() would also remove axis 0 (the example axis) when the
  # batch holds a single caption, breaking the encoder's rank-3 input.
  singleton_axes = tuple(ax for ax in range(1, stacked.ndim)
                         if stacked.shape[ax] == 1)
  return np.squeeze(stacked, axis=singleton_axes), np.array(captions)


if __name__ == "__main__":
  base_directory = '/content/gdrive/My Drive/Colab Notebooks/image_captioning_with_attention'
  (train_dataset_list_path, val_dataset_list_path, train_feature_path,
   val_feature_path, train_seq_path, val_seq_path, max_len, vocab_size,
   req_train_list_path, req_val_list_path, req_token_path) = get_path(base_directory)

  # pickle.load can execute arbitrary code: these files must come from this
  # project's own preprocessing, never from an untrusted source.
  with open(req_train_list_path, 'rb') as handle:
    train_list = pickle.load(handle)
  with open(req_val_list_path, 'rb') as handle:
    val_list = pickle.load(handle)
  with open(req_token_path, 'rb') as handle:
    tokenizer = pickle.load(handle)

  save_path = os.path.join(base_directory, 'saved_model_merge')
  # Make sure the checkpoint directory exists before the first save.
  os.makedirs(save_path, exist_ok=True)

  embedding_size = 256
  units = 512
  reg = 1e-4
  batch_size = 100
  num_epochs = 20

  encoder = Encoder(embedding_size, reg).build()
  decoder = Decoder(max_len, embedding_size, units, vocab_size, reg).build()

  optimizer = optimizers.Adam()
  # The decoder's output layer (fc2) already applies softmax, so the loss
  # must treat its output as probabilities. from_logits=True would apply a
  # second softmax inside the loss.
  compile_loss = losses.SparseCategoricalCrossentropy(from_logits=False, reduction='none')

  def run_epoch(data_list, feature_store, seq_store, step_fn):
    """Run step_fn over data_list in chunks of batch_size image ids.

    Returns (mean accuracy, mean loss) over the batches.
    """
    batch_acc = []
    batch_loss = []
    for start in range(0, len(data_list), batch_size):
      img, text = get_saved_data(data_list[start:start + batch_size],
                                 feature_store, seq_store)
      acc, loss = step_fn(encoder, decoder, img, text,
                          tokenizer, optimizer, compile_loss)
      batch_acc.append(acc)
      batch_loss.append(loss)
    return np.mean(batch_acc), np.mean(batch_loss)

  # Per-epoch history.
  train_loss = []
  train_acc = []
  val_loss = []
  val_acc = []

  train_feature, train_seq, val_feature, val_seq = get_data(
      train_feature_path, train_seq_path, val_feature_path, val_seq_path)
  try:
    for epoch in range(num_epochs):
      t_acc, t_loss = run_epoch(train_list, train_feature, train_seq, train_step)
      v_acc, v_loss = run_epoch(val_list, val_feature, val_seq, test_step)
      train_acc.append(t_acc)
      train_loss.append(t_loss)
      val_acc.append(v_acc)
      val_loss.append(v_loss)

      # Report and tag checkpoints with this epoch's metrics, not an average
      # over all epochs so far.
      print('epoch {0:4d} train acc {1:0.3f} loss {2:0.3f} val acc {3:0.3f} loss {4:0.3f}'.
            format(epoch, t_acc, t_loss, v_acc, v_loss))
      encoder_path = os.path.join(save_path,'encoder_model_{0:02d}_vacc_{1:0.3f}_vloss_{2:0.3f}_acc{3:0.3f}_loss{4:0.3f}.h5'.
                                  format(epoch, v_acc, v_loss, t_acc, t_loss))
      decoder_path = os.path.join(save_path,'decoder_model_{0:02d}_vacc_{1:0.3f}_vloss_{2:0.3f}_acc{3:0.3f}_loss{4:0.3f}.h5'.
                                  format(epoch, v_acc, v_loss, t_acc, t_loss))

      encoder.save(encoder_path)
      decoder.save(decoder_path)
  finally:
    # Release the HDF5 handles even if training fails part-way.
    train_feature.close()
    train_seq.close()
    val_feature.close()
    val_seq.close()


epoch    0 train acc 0.049 loss 1.707 val acc 0.066 loss 1.527
epoch    1 train acc 0.061 loss 1.548 val acc 0.072 loss 1.434
epoch    2 train acc 0.070 loss 1.440 val acc 0.080 loss 1.365
epoch    3 train acc 0.078 loss 1.357 val acc 0.086 loss 1.316
epoch    4 train acc 0.083 loss 1.292 val acc 0.089 loss 1.282
epoch    5 train acc 0.088 loss 1.241 val acc 0.092 loss 1.257
epoch    6 train acc 0.091 loss 1.199 val acc 0.094 loss 1.238
epoch    7 train acc 0.094 loss 1.163 val acc 0.096 loss 1.225
epoch    8 train acc 0.097 loss 1.133 val acc 0.097 loss 1.214
epoch    9 train acc 0.099 loss 1.106 val acc 0.098 loss 1.206
epoch   10 train acc 0.101 loss 1.081 val acc 0.099 loss 1.199
epoch   11 train acc 0.103 loss 1.059 val acc 0.099 loss 1.194
epoch   12 train acc 0.104 loss 1.039 val acc 0.100 loss 1.190
epoch   13 train acc 0.106 loss 1.021 val acc 0.101 loss 1.187
epoch   14 train acc 0.108 loss 1.004 val acc 0.101 loss 1.184
epoch   15 train acc 0.109 loss 0.989 val acc 0.102 los