<a href="https://colab.research.google.com/github/LeehyeongTea/image_captioning_with_attention/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import h5py
import os
import tensorflow as tf
from tensorflow.keras.layers import Dense, LSTM, Input, Embedding, Dropout,BatchNormalization,Lambda, Add,Flatten,GRU
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.models import Model

from tensorflow.keras import regularizers,optimizers,losses,metrics
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

from tensorflow.keras.utils import plot_model,to_categorical
import pickle


def get_path(base_directory):
  """Read the preprocessing metadata stored in ``<base_directory>/data/needs.hdf5``.

  Args:
    base_directory: project root that contains the ``data`` folder.

  Returns:
    An 11-tuple, in this order: train code-list path, val code-list path,
    train feature path, val feature path, train sequence path, val sequence
    path, ``max_len``, ``vocab_size``, train-list pickle path, val-list pickle
    path, tokenizer pickle path. Each value is whatever ``dataset[()]`` yields
    for that key (string entries may come back as ``bytes`` depending on the
    h5py version -- TODO confirm).
  """
  data_h5_path = os.path.join(base_directory, 'data', 'needs.hdf5')
  keys = (
      'train_code_path',
      'val_code_path',
      'train_feature_path',
      'val_feature_path',
      'train_seq_path',
      'val_seq_path',
      'max_len',
      'vocab_size',
      'req_train_list_path',
      'req_val_list_path',
      'req_token_path',
  )
  # `[()]` reads each dataset fully into memory, so the file can be closed
  # as soon as all values have been read.
  with h5py.File(data_h5_path, 'r') as needs:
    return tuple(needs[key][()] for key in keys)


def get_data(train_feature_path,train_seq_path, val_feature_path,val_seq_path):
  """Open the four HDF5 stores read-only.

  Returns:
    ``(train_feature, train_seq, val_feature, val_seq)`` as open ``h5py.File``
    handles. The caller owns them and must close them.

  If opening any file fails, the handles opened so far are closed before the
  error propagates, so nothing leaks on partial failure.
  """
  import contextlib

  paths = (train_feature_path, train_seq_path, val_feature_path, val_seq_path)
  with contextlib.ExitStack() as stack:
    handles = tuple(stack.enter_context(h5py.File(path, 'r')) for path in paths)
    # Every open succeeded: hand ownership to the caller instead of closing.
    stack.pop_all()
  return handles

# Apply Bahdanau (additive) attention.
class Attention(Model):
  """Additive (Bahdanau-style) attention over a set of feature vectors.

  score = V(tanh(W1(values) + W2(query))), normalised with a softmax over
  axis 1; the context vector is the attention-weighted sum of ``values``
  over that same axis.
  """

  def __init__(self, units, embedding_size):
    super(Attention, self).__init__()
    self.W1 = Dense(units)
    self.W2 = Dense(units)
    self.V = Dense(1)
    # Stored for reference; nothing in this class reads it.
    self.embedding_size = embedding_size
    self.units = units

  def call(self, values, query):
    """Compute the context vector and attention weights.

    Args:
      values: features of shape (batch, locations, channels), as declared
        in ``build``.
      query: tensor of shape (batch, units).

    Returns:
      ``(context_vector, attention_dist)`` with shapes (batch, channels) and
      (batch, locations, 1).
    """
    # (batch, units) -> (batch, 1, units) so it broadcasts over every location.
    query = tf.expand_dims(query, axis=1)

    score = self.V(tf.nn.tanh(self.W1(values) + self.W2(query)))

    attention_dist = tf.nn.softmax(score, axis=1)
    context_vector = tf.reduce_sum(attention_dist * values, axis=1)

    return context_vector, attention_dist

  def build(self, feature_shape=(64, 2048)):
    """Wrap ``call`` in a functional ``Model`` with symbolic inputs.

    Note: this replaces ``keras.Model.build`` with a factory that returns a
    new functional model; the file uses it as ``Attention(...).build()``.

    Args:
      feature_shape: per-example shape of ``values`` (without the batch
        axis). Defaults to the previously hard-coded ``(64, 2048)``.
    """
    values = Input(shape=tuple(feature_shape))
    query = Input(shape=(self.units,))

    return Model(inputs=[values, query], outputs=self.call(values, query))


class Decoder(Model):
  """Caption decoder combining an LSTM over the word prefix with attention.

  The attended image context is projected to ``units`` by ``fc1`` and added
  to the LSTM's final output. ``fc2`` then maps the sum to a softmax over the
  vocabulary. ``hidden`` is used only as the attention query; it is not
  passed to the LSTM as an initial state.
  """

  def __init__(self, max_len, embedding_size, units, vocab_size, reg):
    super(Decoder, self).__init__()
    self.units = units
    self.embedding_size = embedding_size
    # mask_zero=True: token id 0 is treated as padding and masked downstream.
    self.embedding_layer = Embedding(vocab_size, embedding_size, mask_zero=True)
    self.lstm = LSTM(units)
    self.max_len = max_len
    # The attention block is materialised as a functional sub-model up front.
    self.attention = Attention(units, embedding_size).build()
    self.fc1 = Dense(self.units, activation='relu', kernel_regularizer=regularizers.l2(reg))
    self.fc2 = Dense(vocab_size, activation='softmax', kernel_regularizer=regularizers.l2(reg))

  def call(self, sequence, img, hidden):
    """Run one decoding step.

    Args:
      sequence: token ids of shape (batch, max_len), as declared in ``build``.
      img: image features of shape (batch, 64, 2048), as declared in ``build``.
      hidden: attention query of shape (batch, units).

    Returns:
      ``(probs, output, attention_dist)``: vocabulary softmax, the LSTM's
      final output (callers feed it back as the next ``hidden``), and the
      attention weights.
    """
    context_vector, attention_dist = self.attention([img, hidden])
    embedded = self.embedding_layer(sequence)

    output = self.lstm(embedded)

    x = self.fc1(context_vector)

    merge = Add()([output, x])
    merge = self.fc2(merge)

    return merge, output, attention_dist

  def build(self):
    """Return a functional ``Model`` over ``[sequence, img, hidden]`` inputs."""
    sequence = Input(shape=(self.max_len,))
    img = Input(shape=(64, 2048))
    # Trailing comma: a 1-tuple shape, not a parenthesised int.
    hidden = Input(shape=(self.units,))
    return Model(inputs=[sequence, img, hidden], outputs=self.call(sequence, img, hidden))


@tf.function
def train_step(decoder,img, text,max_len,
               tokenizer,optimizer,compile_loss,train_loss,train_acc, units=512):
  """Run one teacher-forced optimisation step on a batch of captions.

  For each position ``i`` in ``1 .. T-1`` (``T = text.shape[1]``), the
  prefix ``text[:, :i]`` is right-padded with zeros up to ``max_len`` and the
  decoder is asked to predict ``text[:, i]``. The decoder's second output is
  fed back as the next step's ``hidden``. The per-position losses are summed
  for the gradient. ``train_loss`` receives their mean over the ``T-1``
  predicted positions.

  NOTE(review): if ``text`` is zero-padded, the padded target positions also
  contribute to the loss and to ``train_acc`` -- confirm this is intended.

  Args:
    decoder: model called as ``decoder([prefix, img, hidden])`` that returns
      three outputs.
    img: image features for the batch.
    text: integer token matrix of shape (batch, T); assumes ``T <= max_len``.
    max_len: length the prefix is padded to.
    tokenizer: unused; kept for call-site compatibility.
    optimizer: applies the computed gradients.
    compile_loss: per-position loss ``fn(target, output)``.
    train_loss: metric updated with the mean per-position loss.
    train_acc: metric updated with every (target, output) pair.
    units: width of the decoder's ``hidden`` input. It must match the
      ``units`` the decoder was built with (default: the previously
      hard-coded 512).
  """
  num_steps = text.shape[1] - 1
  loss = 0
  hidden = tf.zeros((text.shape[0], units))

  with tf.GradientTape() as tape:
    for i in range(1, text.shape[1]):
      input_text = tf.pad(text[:, :i], [[0, 0], [0, max_len - i]], "CONSTANT")
      target = text[:, i]

      output, hidden, _ = decoder([input_text, img, hidden])
      loss += compile_loss(target, output)
      train_acc(target, output)

  gradients = tape.gradient(loss, decoder.trainable_variables)
  optimizer.apply_gradients(zip(gradients, decoder.trainable_variables))
  # Average over the T-1 positions actually predicted (previously divided by T).
  train_loss(loss / max(num_steps, 1))
  

@tf.function
def test_step(decoder,img, text,max_len,
               tokenizer,optimizer,compile_loss,val_loss,val_acc, units=512):
  """Evaluate one batch of captions with teacher forcing (no weight update).

  This mirrors ``train_step``. For each position ``i`` in ``1 .. T-1``, the
  zero-padded prefix ``text[:, :i]`` is used to predict ``text[:, i]``.
  ``val_loss`` receives the mean loss over the ``T-1`` predicted positions.

  NOTE(review): if ``text`` is zero-padded, the padded target positions also
  contribute to the loss and to ``val_acc`` -- confirm this is intended.

  Args:
    decoder: model called as ``decoder([prefix, img, hidden])`` that returns
      three outputs.
    img: image features for the batch.
    text: integer token matrix of shape (batch, T); assumes ``T <= max_len``.
    max_len: length the prefix is padded to.
    tokenizer: unused; kept for call-site compatibility.
    optimizer: unused; kept for call-site compatibility.
    compile_loss: per-position loss ``fn(target, output)``.
    val_loss: metric updated with the mean per-position loss.
    val_acc: metric updated with every (target, output) pair.
    units: width of the decoder's ``hidden`` input (default: the previously
      hard-coded 512).
  """
  num_steps = text.shape[1] - 1
  loss = 0
  hidden = tf.zeros((text.shape[0], units))

  for i in range(1, text.shape[1]):
    input_text = tf.pad(text[:, :i], [[0, 0], [0, max_len - i]], "CONSTANT")
    target = text[:, i]

    output, hidden, _ = decoder([input_text, img, hidden])
    loss += compile_loss(target, output)
    val_acc(target, output)

  # Average over the T-1 positions actually predicted (previously divided by T).
  val_loss(loss / max(num_steps, 1))
  


def get_feature_x_y(features,seq,elem):
  """Fetch the stored feature array and caption sequences for key ``elem``.

  Returns:
    ``(features[elem][:], seq[elem][:])``, i.e. a full slice of each entry.
  """
  return features[elem][:], seq[elem][:]


def get_saved_data(data_list, feature, seq):
  """Assemble a batch for the keys in ``data_list``.

  Each key contributes one row per caption: its feature array is repeated
  once for every caption row stored under the same key.

  Args:
    data_list: keys to look up in ``feature`` and ``seq``.
    feature: mapping from key to feature array (e.g. an open HDF5 file).
    seq: mapping from key to a sequence of caption rows.

  Returns:
    ``(F, T)``. ``F`` stacks the repeated features, with every singleton
    axis removed except the leading batch axis. ``T`` stacks the caption
    rows.
  """
  features = []
  texts = []
  for elem in data_list:
    f, t = get_feature_x_y(feature, seq, elem)
    features.extend([f] * len(t))
    texts.extend(t)

  F = np.array(features)
  # Drop singleton axes as a plain .squeeze() would, but never axis 0: a
  # batch containing exactly one caption must keep its batch dimension.
  F = F.reshape(F.shape[:1] + tuple(d for d in F.shape[1:] if d != 1))
  return F, np.array(texts)


if __name__ == "__main__" :
  base_directory = '/content/gdrive/My Drive/Colab Notebooks/image_captioning_with_attention'
  train_dataset_list_path, val_dataset_list_path, train_feature_path, val_feature_path, train_seq_path,val_seq_path, max_len,vocab_size, req_train_list_path, req_val_list_path, req_token_path = get_path(base_directory)


  with open(req_train_list_path, 'rb') as handle:
    train_list = pickle.load(handle)
  with open(req_val_list_path,'rb') as handle:
    val_list = pickle.load(handle)
  with open(req_token_path,'rb') as handle:
    tokenizer = pickle.load(handle)       

  save_path = os.path.join(base_directory,'merge_model','saved_model')

  train_feature, train_seq, val_feature, val_seq= get_data(train_feature_path, train_seq_path,val_feature_path, val_seq_path)


  embedding_size = 256
  units = 512
  reg = 1e-4
  

  decoder = Decoder(max_len, embedding_size, units, vocab_size,reg).build()




  optimizer = optimizers.Adam()
  compile_loss = losses.SparseCategoricalCrossentropy()

  train_loss = metrics.Mean()
  train_acc = metrics.SparseCategoricalAccuracy()
  
  val_loss = metrics.Mean()
  val_acc = metrics.SparseCategoricalAccuracy()
  get_batch_list = list()
  
  
  batch_size = 32

  for epoch in range(0,20):
    get_batch_list = list()

    for i in range(0, len(train_list)):
      get_batch_list.append(train_list[i])
      if i % batch_size == 0 and i != 0 :
        img, text = get_saved_data(get_batch_list, train_feature, train_seq)
        train_step(decoder,img,text,max_len,
                   tokenizer,optimizer,compile_loss,train_loss,train_acc)
        get_batch_list.clear()

    if len(get_batch_list) != 0 :
      img, text = get_saved_data(get_batch_list, train_feature, train_seq)
      train_step(decoder,img,text,max_len,
                 tokenizer,optimizer,compile_loss,train_loss,train_acc)
      get_batch_list.clear()
    
    for i in range(0, len(val_list)):
      get_batch_list.append(val_list[i])
      if i % batch_size == 0 and i != 0 :
        img, text = get_saved_data(get_batch_list, val_feature, val_seq)
        test_step(decoder,img,text,max_len,
                  tokenizer,optimizer,compile_loss,val_loss,val_acc)
        get_batch_list.clear()

    if len(get_batch_list) != 0 :
      img, text = get_saved_data(get_batch_list, val_feature, val_seq,)
      test_step(decoder,img,text,max_len,
                tokenizer,optimizer,compile_loss,val_loss,val_acc)
      get_batch_list.clear()
    print('epoch {0:4d} train acc {1:0.3f} loss {2:0.3f} val acc {3:0.3f} loss {4:0.3f}'.
          format(epoch, train_acc.result(), train_loss.result(), val_acc.result(), val_loss.result()))
    decoder_path = os.path.join(save_path,'decoder_model_{0:02d}_vacc_{1:0.3f}_vloss_{2:0.3f}_acc{3:0.3f}_loss{4:0.3f}.h5'.
                        format(epoch, val_acc.result(), val_loss.result(), train_acc.result(), train_loss.result()))
    

    decoder.save(decoder_path)


  train_feature.close()
  train_seq.close()
  
  val_feature.close()
  val_seq.close()


epoch    0 train acc 0.733 loss 1.710 val acc 0.766 loss 1.332
epoch    1 train acc 0.757 loss 1.441 val acc 0.777 loss 1.230
epoch    2 train acc 0.770 loss 1.296 val acc 0.783 loss 1.172
epoch    3 train acc 0.778 loss 1.198 val acc 0.787 loss 1.137
epoch    4 train acc 0.785 loss 1.124 val acc 0.789 loss 1.115
epoch    5 train acc 0.790 loss 1.066 val acc 0.791 loss 1.100
epoch    6 train acc 0.795 loss 1.019 val acc 0.792 loss 1.092
epoch    7 train acc 0.799 loss 0.978 val acc 0.793 loss 1.087
epoch    8 train acc 0.803 loss 0.943 val acc 0.794 loss 1.082
epoch    9 train acc 0.807 loss 0.911 val acc 0.795 loss 1.079
epoch   10 train acc 0.810 loss 0.883 val acc 0.796 loss 1.076
epoch   11 train acc 0.813 loss 0.857 val acc 0.796 loss 1.075
epoch   12 train acc 0.816 loss 0.833 val acc 0.796 loss 1.075
epoch   13 train acc 0.819 loss 0.812 val acc 0.797 loss 1.076
epoch   14 train acc 0.822 loss 0.792 val acc 0.797 loss 1.077
epoch   15 train acc 0.825 loss 0.774 val acc 0.797 los