<a href="https://colab.research.google.com/github/DrishtiSabhaya/ImageCaptionGenerator/blob/master/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import string

def load_description(text):
	mapping = dict()
	for line in text.split('\n'):
		token = line.split()
		if len(line)<2:
			continue
		img_id, img_desc = token[0], token[1:]
		img_id = img_id.split('.')[0]
		img_desc = ' '.join(img_desc)
		if img_id not in mapping:                            
			mapping[img_id] = list()
		mapping[img_id].append(img_desc)
	return mapping

def clean_text(descriptions):
	table = str.maketrans('', '', string.punctuation)      # translation table for removing all types of punctuation
	for key, desc_list in descriptions.items():
		for i in range(len(desc_list)):
			desc = desc_list[i]
			desc = desc.split()
			desc = [word.lower() for word in desc]
			desc = [w.translate(table) for w in desc]
			desc = [word for word in desc if len(word)>1]
			desc = [word for word in desc if word.isalpha()]
			desc_list[i] = ' '.join(desc)
			
def save_desc(descriptions, file2):
	lines = []
	for key, desc_list in descriptions.items():
		for desc in desc_list:
			lines.append(key+' '+desc)
	data = '\n'.join(lines)
	f = open(file2,'w')
	f.write(data)
	f.close()
	

	
file1 = open('Flickr8k.token.txt','r')
text = file1.read()

descriptions = load_description(text)
clean_text(descriptions)
save_desc(descriptions,'desc.txt')


In [2]:
import pickle
from os import listdir
from pickle import dump
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.models import Model

def extract_features(directory):
	model = VGG16()
	model = Model(inputs=model.inputs, outputs=model.layers[-2].output)
	print(model.summary())
	
	features = dict()
	for name in listdir(directory):
		file1 = directory +'/'+ name
		image = load_img(file1, target_size=(224,224))
		image = img_to_array(image)
		image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
		image = preprocess_input(image)
		feature = model.predict(image, verbose=0)
		img_id = name.split('.')[0]
		features[img_id] = feature
	return features

directory = '/content/drive/My Drive/Datasets/Images/Flicker8k_Dataset'
features = extract_features(directory)
print(len(features))

pickle.dump(features,open('features.pkl','wb'))

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)   

In [3]:
from numpy import array
from pickle import load
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Embedding
from keras.layers import Dropout
from keras.layers.merge import add
from keras.callbacks import ModelCheckpoint

def load_set(filename):
	file1 = open(filename,'r')
	text = file1.read()
	file1.close()
	dataset = []
	for line in text.split('\n'):
		if len(line)<1:
			continue
		identifier = line.split('.')[0]
		dataset.append(identifier)
	return set(dataset)
	
def load_desc(filename,dataset):
	file2 = open(filename,'r')
	doc = file2.read()
	descriptions = dict()
	for line in doc.split('\n'):
		tokens = line.split()
		img_id, img_desc = tokens[0], tokens[1:]
		if img_id in dataset:
			if img_id not in descriptions:
				descriptions[img_id] = list()
			desc = 'start' + ' '.join(img_desc) + 'end'
			descriptions[img_id].append(desc)
	return descriptions
	
def load_features(filename, dataset):
	# load all features
	l = []
	all_features = load(open(filename, 'rb'))
	# filter features
	features = {k: all_features[k] for k in dataset}
	return features
	
def to_line(descriptions):
	desc = []
	for key in descriptions.keys():
		[desc.append(value) for value in descriptions[key]]
	return desc
	
def tokenize(descriptions):
	lines = to_line(descriptions)
	tokenizer = Tokenizer()
	tokenizer.fit_on_texts(lines)
	return tokenizer
	
def max_length(descriptions):	
	lines = to_line(descriptions)
	return max(len(d.split())for d in lines)
	
def sequences(tokenizer, max_length, desc_list, photos, vocab_size):
	X1, X2, y = [], [], []
	for desc in desc_list:
		seq = tokenizer.texts_to_sequences([desc])[0]                          #encode the sequence
		for i in range(1,len(seq)):
			in_seq, out_seq = seq[:i], seq[i]                                  # split into input and output pair
			in_seq = pad_sequences([in_seq], maxlen = max_length)[0]            #pad sequences to match length of every sequence
			out_seq = to_categorical([out_seq], num_classes = vocab_size)[0]
				
			X1.append(photos)
			X2.append(in_seq)
			y.append(out_seq)
				
	return array(X1), array(X2), array(y)
	
def def_model(vocab_size, max_length):
	inputs1 = Input(shape=(4096,))
	fe1 = Dropout(0.5)(inputs1)
	fe2 = Dense(256, activation = 'relu')(fe1)
	
	inputs2 = Input(shape=(max_length,))
	se1 = Embedding(vocab_size, 256, mask_zero = True)(inputs2)
	se2 = Dropout(0.5)(se1)
	se3 = LSTM(256)(se2)
	
	decoder1 = add([fe2, se3])
	decoder2 = Dense(256, activation = 'relu')(decoder1)
	outputs = Dense(vocab_size, activation = 'softmax')(decoder2)
	
	model = Model(inputs = [inputs1, inputs2], outputs = outputs)
	model.compile(loss = 'categorical_crossentropy', optimizer = 'adam')
	model.summary()
	plot_model(model, to_file = 'model.png', show_shapes = True)
	return model
	
def data_generator(descriptions, photos, tokenizer, max_length, vocab_size):
	while 1:
		for key, desc_list in descriptions.items():
			photo = photos[key][0]
			in_img, in_seq, out_word = sequences(tokenizer, max_length, desc_list, photo, vocab_size)
			yield [in_img, in_seq], out_word
					
				
filename = 'Flickr_8k.trainImages.txt'
train = load_set(filename)
train_desc = load_desc('desc.txt',train)
train_features = load_features('features.pkl',train)

tokenizer = tokenize(train_desc)
vocab_size = len(tokenizer.word_index)+1
max_length = max_length(train_desc)

model = def_model(vocab_size, max_length)
epochs = 20
steps = len(train_desc)
for i in range(epochs):
	generator = data_generator(train_desc, train_features, tokenizer, max_length, vocab_size)
	model.fit_generator(generator, epochs = 1, steps_per_epoch = steps, verbose = 2)
	model.save('model_'+str(i)+'.h5')

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 32)]         0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 4096)]       0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 32, 256)      2662400     input_3[0][0]                    
__________________________________________________________________________________________________
dropout (Dropout)               (None, 4096)         0           input_2[0][0]                    
_______________________________________________________________________________________