<a href="https://colab.research.google.com/github/amifunny/Deep-Learning-Notebook/blob/master/Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# TensorFlow is imported in this cell as well, because it sits above the main
# import cell and would otherwise fail with NameError on a fresh kernel.
import tensorflow as tf

print( tf.__version__ )

In [0]:
"""
  Image Captioning - Based on End to End Modeling

  See Paper - 'https://arxiv.org/pdf/1411.4555.pdf'

  Encoder - PreTrained Imagenet ResNet50 Model
  Decoder - LSTM with Dimension 512

  Sampling Via Beam Search

"""

In [0]:
import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image

import os
import json
import re

from collections import Counter
import random

In [0]:
!wget 'http://images.cocodataset.org/zips/train2014.zip'

In [0]:
%%capture
!unzip 'train2014.zip'

In [0]:
!wget 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip'

In [0]:
!unzip 'annotations_trainval2014.zip'

In [0]:
# ---- Model sizes ----
# enc_dim sizes the decoder's initial-state input; hidden_dim sizes the
# encoder's final Dense layer and the GRU. The encoder output is fed in as the
# GRU's initial state, so the two must stay equal.
hidden_dim = 512
enc_dim = 512
# Width of the word-embedding vectors.
embed_dim = 100

# ---- Data settings ----
# Images are resized to img_size x img_size before being fed to the encoder.
img_size = 224
# Vocabulary size, including the four special tokens.
vocab_size = 10000
# Number of examples per training batch.
batch_size = 32

In [0]:
# Collect every token of every training caption; the most frequent tokens
# become the vocabulary in the next cell.
annotation_file = '/content/annotations/captions_train2014.json'

# The `with` block closes the file on exit, so no explicit close() is needed.
# json.load parses straight from the handle, so the raw text is not kept
# around in a global.
with open(annotation_file, encoding='utf-8') as file:
  annotate_dict = json.load(file)

all_words = []

for annotation in annotate_dict['annotations']:

  img_caption = annotation['caption']

  # Keep only letters, digits, spaces and full stops ...
  cleaned_caps = re.sub( r'[^a-zA-Z0-9. ]' , '' , img_caption )
  # ... and put a space before each full stop so it becomes its own token.
  cleaned_caps = re.sub( r'[.]' , ' .', cleaned_caps )

  capt_words = cleaned_caps.lower().split()
  all_words.extend( capt_words )

In [0]:
# The special tokens take the first ids; real words are numbered after them.
word_to_int = { '<pad>':0,'<start>':1,'<end>':2,'<unk>':3 }
int_to_word = { 0:'<pad>',1:'<start>',2:'<end>',3:'<unk>' }
before_size = len( word_to_int )

# Keep only as many of the most frequent words as fit, so that
# special tokens + words == vocab_size.
occur_list = Counter(all_words)
vocab_words = occur_list.most_common( vocab_size-before_size )

for word_id , (word , _count) in enumerate( vocab_words , start=before_size ):

  word_to_int[ word ] = word_id
  int_to_word[ word_id ] = word

print( len(word_to_int) )
print( len(int_to_word) )
# Print a short preview only; dumping all vocab_size entries floods the
# cell output. Ids are contiguous from 0, so indexing by range is safe.
print( [ ( idx , int_to_word[idx] ) for idx in range( min( 20 , len(int_to_word) ) ) ] )


In [0]:
def get_shuffled_annots(start=0,limited_num=1024,
                        annotation_file='/content/annotations/captions_train2014.json'):
  """Load a window of caption annotations and return it in random order.

  Args:
    start: index of the first annotation to take.
    limited_num: how many annotations to take. This is a count, not an end
      index; the window is annotations[start:start+limited_num].
    annotation_file: path of the annotations JSON. It must contain a
      top-level 'annotations' list. Defaults to the COCO train2014 captions
      file location used throughout this notebook.

  Returns:
    A new list holding the selected annotation entries, shuffled with the
    global `random` state. The list is empty when `start` is past the end.
  """
  # The `with` block closes the file; json.load reads directly from the handle.
  with open(annotation_file, encoding='utf-8') as file:
    annotate_dict = json.load(file)

  content_list = annotate_dict['annotations'][start:start+limited_num]

  # shuffle image annotations and hence the images
  random.shuffle( content_list )

  return content_list

def get_data( content_list ,start=0,noi=100):
  """Load images and tokenised captions for a window of annotations.

  Args:
    content_list: list of annotation dicts, each with an 'image_id' (int)
      and a 'caption' (str) entry.
    start: index of the first annotation to load.
    noi: maximum number of annotations to load. The window is truncated at
      the end of `content_list`.

  Returns:
    (images, images_caption). `images` is a list of float arrays of shape
    (img_size, img_size, 3) scaled to [-1, 1]. `images_caption` is a list of
    token lists, each wrapped in the start and end tokens.
  """
  images = []
  images_caption = []

  # Slicing clamps at the end of the list, so a short final window is fine.
  for annot in content_list[start:start+noi]:

    img_id = annot['image_id']
    img_caption = annot['caption']
    img_name = 'COCO_train2014_'+"{:012d}".format(img_id)+".jpg"

    # The context manager releases the file handle once the pixels are read.
    # convert('RGB') turns every source mode (greyscale, CMYK, RGBA, ...)
    # into exactly three channels. The old shape check only coped with
    # greyscale and built a 4-D array for 4-channel images.
    with Image.open( '/content/train2014/'+img_name ) as img:
      rgb_img = img.convert('RGB').resize( (img_size,img_size) )
      # 8-bit pixels 0..255 -> floats in [-1, 1]
      img_arr = np.array( rgb_img )/127.5 - 1.0

    # Keep only letters, digits, spaces and full stops, then split each
    # full stop off as its own token.
    cleaned_caps = re.sub( r'[^a-zA-Z0-9. ]' , '' , img_caption )
    cleaned_caps = re.sub( r'[.]' , ' .', cleaned_caps )

    # int_to_word[1] / int_to_word[2] are the start / end tokens.
    splited_caps = [int_to_word[1]] + cleaned_caps.lower().split() + [int_to_word[2]]

    images.append( img_arr )
    images_caption.append( splited_caps )

  return images,images_caption

In [0]:
# Load a small shuffled sample and eyeball a few image/caption pairs.
annot_list = get_shuffled_annots()
# *******
total_examples = len( annot_list )
print( "Total Number of Examples ==>  {}".format(total_examples) )
# *******


images,images_caption = get_data( annot_list , start=0,noi=256)
print( len(images) )
print( len(images_caption) )

for i in range(3):
  print(images_caption[i])
  print(images[i].shape)
  # get_data scales pixels to [-1, 1]; imshow expects floats in [0, 1] and
  # clips anything outside, so map the values back before plotting.
  plt.imshow( (images[i]+1.0)/2.0 )
  plt.show()


In [0]:
def convert_to_int(word_captions):
  """Translate tokenised captions into lists of vocabulary ids.

  Args:
    word_captions: iterable of captions, each an iterable of word tokens.

  Returns:
    A list with one list of ids per caption. A token that is not a key of
    `word_to_int` is mapped to the id of '<unk>'.
  """
  return [
      [ word_to_int.get( token , word_to_int['<unk>'] ) for token in caption ]
      for caption in word_captions
  ]

In [0]:
def convert_to_word(int_captions):
  """Translate captions given as vocabulary ids back into word tokens.

  Args:
    int_captions: iterable of captions, each an iterable of ids that are
      keys of `int_to_word`.

  Returns:
    A list with one list of word tokens per caption. An id missing from
    `int_to_word` raises KeyError.
  """
  return [
      [ int_to_word[token_id] for token_id in caption ]
      for caption in int_captions
  ]

In [0]:
   
# Module-level placeholders; they are overwritten when prep_batch's result is
# unpacked in the next cell.
img_batches = []
cap_batches = []

# Deliberately NOT wrapped in tf.function. This is plain Python/NumPy data
# preparation on Python lists, so tracing it only caused a retrace (and a
# warning) on every call with new data.
def prep_batch(imgs,caps,batch_size):
    """Group images and integer captions into fixed-size batches of tensors.

    Args:
      imgs: list of equally-shaped image arrays.
      caps: list of integer-id captions, aligned with `imgs`.
      batch_size: number of examples per batch.

    Returns:
      (img_batches, cap_batches, num_of_batches). `img_batches` is a list of
      float32 tensors. `cap_batches` is a list of integer tensors in which
      every caption is right-padded with 0 to the longest caption of its own
      batch. `num_of_batches` is a Python int. A trailing remainder of fewer
      than `batch_size` examples is dropped.
    """
    num_of_batches = len(imgs)//batch_size

    img_batches = []
    cap_batches = []

    for i in range(num_of_batches):

      window = slice( i*batch_size , (i+1)*batch_size )

      img_batch = tf.convert_to_tensor( imgs[window] , tf.float32 )
      img_batches.append( img_batch )

      # padding='post' appends the zeros after the caption tokens.
      padded_batch = tf.keras.preprocessing.sequence.pad_sequences( caps[window] ,padding='post')
      cap_batch = tf.convert_to_tensor(padded_batch)
      cap_batches.append( cap_batch )

    return img_batches,cap_batches,num_of_batches

In [0]:
# Turn the word captions into ids and cut everything into batches of 32.
int_caps = convert_to_int( images_caption )
img_batches,cap_batches,num_of_batches = prep_batch( images , int_caps , 32 )

# Sanity check: shapes of the first batch and the number of batches.
first_img_batch = img_batches[0]
first_cap_batch = cap_batches[0]
print( first_img_batch.shape )
print( first_cap_batch.shape )
print( num_of_batches )

# int caption batch (only the first one is shown)
print( first_cap_batch )

In [0]:
# Image encoder: CNN backbone -> global average pooling -> two Dense layers.
# NOTE: despite the variable name `resnet`, the backbone built here is InceptionV3.
# No `weights` argument is passed. NOTE(review): the Keras default presumably
# loads ImageNet weights -- confirm against the installed TF version.
resnet = tf.keras.applications.InceptionV3( include_top=False , input_shape=[img_size,img_size,3] )
# The backbone is left trainable, so it is updated together with the new layers.
resnet.trainable = True

inputs = tf.keras.layers.Input( shape=[img_size,img_size,3] )
out = resnet( inputs )
# Collapse the spatial feature map into one feature vector per image.
out = tf.keras.layers.GlobalAveragePooling2D()(out)
out = tf.keras.layers.Dense( hidden_dim*2 ,activation='relu')(out)
# The final width is hidden_dim; the training cell passes this output to the
# decoder as its initial hidden state.
outputs = tf.keras.layers.Dense( hidden_dim ,activation='relu' )(out)

encoder_model = tf.keras.Model( inputs , outputs )
encoder_model.summary()

In [0]:
"""
  Single Stacked LSTM

  hidden_state of LSTM same as output of encoder

"""

inputs = tf.keras.layers.Input([None])
hidden_state_in = tf.keras.layers.Input([enc_dim])

embed_out = tf.keras.layers.Embedding( vocab_size , embed_dim )(inputs)

out,hidden_state_out = tf.keras.layers.GRU( hidden_dim , return_state=True ,return_sequences=True )(embed_out,initial_state=[hidden_state_in])
out = tf.keras.layers.Dense( hidden_dim*2 )(out)
outputs = tf.keras.layers.Dense( vocab_size , activation='softmax' )(out)

decoder_model = 0
decoder_model = tf.keras.Model( [inputs,hidden_state_in] , [outputs,hidden_state_out] )
decoder_model.summary()

In [0]:
# Training loop.
# Every epoch reshuffles the first subset_size*step_limit annotations and walks
# through them in subsets of subset_size examples; each subset is cut into
# batches and every batch gets one optimizer step with teacher forcing.

#********************
batch_size = 32
l_rate = 0.0001
ctr=0
epochs=10
subset_size = 2048
step_limit = 25
#********************

mean_loss = tf.keras.metrics.Mean()
optimizer = tf.keras.optimizers.RMSprop( l_rate )
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

with tf.device('/device:GPU:0'):

  for e in range(epochs):

    mean_loss.reset_states()

    print("********* EPOCH :: {}".format(e))

    annot_list = get_shuffled_annots( 0 ,subset_size*step_limit )
    offset = 0
    steps = 0

    # Divide the learning rate by 5 on every even epoch after the first.
    if e%2==0 and e!=0:
      l_rate = l_rate/5.0
      optimizer.learning_rate = l_rate

    while steps<step_limit:

      steps = steps + 1
      images,images_caption = get_data( annot_list , offset , subset_size )
      int_captions = convert_to_int( images_caption )
      batches_img,batches_cap,nob = prep_batch(images,int_captions,batch_size)

      offset = offset + subset_size

      for i in range(nob):

        x_batch = batches_img[i]
        y_batch = batches_cap[i]

        # Only one gradient() call is made per batch, so a non-persistent
        # tape is enough; it releases its resources right after that call.
        with tf.GradientTape() as tape:

          encoding_vec = encoder_model( x_batch , training=True)

          # The image encoding is the decoder's initial hidden state.
          hidden_state = encoding_vec

          cost = 0
          dec_input = tf.expand_dims( tf.convert_to_tensor( [ word_to_int['<start>'] ]*batch_size ) , -1 )
          for t in range(1,y_batch.shape[1]):
            pred_output,hidden_state = decoder_model( [dec_input , hidden_state] )
            true_output = tf.expand_dims( y_batch[:,t] , -1 )

            # We use teacher forcing , and give correct label for each timestep as input
            dec_input = true_output

            # Drop only the length-1 time axis. An axis-less squeeze would
            # also drop the batch axis for a batch of one.
            # NOTE(review): padded positions (id 0) are not masked, so they
            # contribute to the loss -- confirm this is intended.
            cost = cost + loss_fn( true_output , tf.squeeze( pred_output , axis=1 ) )


        # You can avg. cost , but we don't as it will slow the learning
        # avg_cost = cost/y_batch.shape[1]
        mean_loss.update_state( cost )

        trainable_variables = encoder_model.trainable_variables + decoder_model.trainable_variables
        grads = tape.gradient( cost,trainable_variables )
        optimizer.apply_gradients(zip(grads,trainable_variables))

      # mean_loss is only reset once per epoch, so this is the running mean
      # since the start of the epoch.
      print("Subset Loss is ==== > {} ".format(mean_loss.result()))

    # Checkpoint both models at the end of every epoch.
    encoder_model.save('enc.h5')
    decoder_model.save('dec.h5')
    print("Epoch Loss is ==== > {} ".format(mean_loss.result()))


"""
  PS : Ignore the WARNING : it is bcz of varing time steps in input , which are expensive to parallelise.
"""

In [0]:
# Write the final encoder and decoder to disk.
for model , filename in ( (encoder_model , 'enc.h5') , (decoder_model , 'dec.h5') ):
  model.save( filename )

In [0]:

# Greedy caption sampling on one batch of annotations outside the training window.

# Training only reads annotations [0, subset_size*step_limit), so a window
# starting at 100000 holds annotations the model was not trained on.
# The second argument of get_shuffled_annots is a COUNT, not an end index:
# (100000, 10000) selects annotations 100000..109999.
annot_list = get_shuffled_annots(100000,10000)
# Load exactly one batch worth of examples.
images,images_caption = get_data( annot_list , 0 , batch_size )
int_captions = convert_to_int( images_caption )
batches_img,batches_cap,nob = prep_batch(images,int_captions,batch_size)

x_batch = batches_img[0]
y_batch = batches_cap[0]

# Inference: run the encoder with training=False so that layers which behave
# differently during training run in inference mode.
encoding_vec = encoder_model( x_batch , training=False)

hidden_state = encoding_vec

dec_input = tf.expand_dims( tf.convert_to_tensor( [ word_to_int['<start>'] ]*batch_size ) , -1 )
# The sampled captions start with the <start> token column.
sample_out = dec_input

# Decode as many steps as the padded reference captions are long.
for t in range(1,y_batch.shape[1]):
  pred_output,hidden_state = decoder_model( [dec_input , hidden_state] )

  # We use Greedy sampling ie take most probable word.
  # The decoder sees one time step, so argmax over the vocabulary axis
  # already has shape (batch, 1); no squeeze/expand round trip is needed.
  max_prob_out = tf.cast( tf.argmax( pred_output , -1 ) , tf.int32 )
  dec_input = max_prob_out

  # concat best prediction of each timestep
  sample_out = tf.concat([sample_out,max_prob_out],-1)


print( sample_out.shape )


In [0]:
# Turn the sampled ids back into words and show each image with its caption.
wcaps = convert_to_word( sample_out.numpy() )
np_batch = x_batch.numpy()
print( wcaps )

for b in range(batch_size):
  predicted_words , pixels = wcaps[b] , np_batch[b]
  print( predicted_words )
  # Pixels are stored in [-1, 1]; shift and scale them to [0, 1] for display.
  show = plt.imshow( (pixels+1.0)/2.0 )
  plt.show()

""" 
  You can see even with only 50k samples subset we get interesting results.
  Even if caption far from correct , it catches few nuances.
  Feel Free to Train on Whole set of 400k images
"""  
  