In [None]:
# Fetch the Flickr8k image and caption archives; -nc skips a download when the
# file is already present, so re-running the cell does not re-fetch them.
!wget -nc https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
!wget -nc https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
# Unpack both archives into the working directory: -q is quiet, -o overwrites
# existing files without prompting.
!unzip -q -o Flickr8k_Dataset.zip
!unzip -q -o Flickr8k_text.zip

In [None]:
import tensorflow as tf
import numpy as np
import os

from pickle import dump
from tensorflow import keras
from keras.models import Model
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from os import listdir

from IPython.display import clear_output

In [None]:
# (height, width, channels) -- presumably the image size fed to the network;
# confirm against the feature-extraction cell.
image_shape = (224, 224, 3)

# Hard-coded image count (presumably the number of files in the image folder --
# verify). NOTE: the name is misspelled ("datase"); it is kept as-is so that
# any other cell referring to it keeps working.
datase_size = 8091
# Number of images pushed through the network per predict() call.
preprocess_batch_size = 3
# Integer ceiling division: a whole number of batches (so progress messages do
# not show "2697.0"), counting a trailing partial batch as one batch.
preprocess_batches = -(-datase_size // preprocess_batch_size)

In [None]:
def extract_features(directory, batch_size=None):
  """Compute a VGG16 feature vector for every file in `directory`.

  Each file is loaded as a 224x224 image, run through Keras' VGG16
  `preprocess_input`, and fed to a VGG16 model truncated at its
  second-to-last layer.

  Args:
    directory: Path of a folder whose entries are all loadable images.
    batch_size: Number of images per `predict` call. Defaults to the
      module-level `preprocess_batch_size`.

  Returns:
    dict mapping image id (the file name up to its first '.') to the feature
    row the truncated model produced for that image.

  Raises:
    ValueError: if the effective batch size is smaller than 1.
  """
  if batch_size is None:
    batch_size = preprocess_batch_size
  if batch_size < 1:
    raise ValueError(f'batch_size must be >= 1, got {batch_size}')

  # load the model
  model = keras.applications.vgg16.VGG16()
  # re-structure the model: expose the second-to-last layer as the output
  model = Model(inputs=model.inputs, outputs=model.layers[-2].output)

  features = dict()
  names = listdir(directory)
  # Ceiling division so a trailing partial batch is counted in the progress total.
  total_batches = -(-len(names) // batch_size)

  for batch_index, start in enumerate(range(0, len(names), batch_size), start=1):
    # The last slice may be shorter than batch_size; it is still processed
    # instead of being silently dropped.
    batch_names = names[start:start + batch_size]

    # Fresh buffer per batch: preprocess_input may write into its argument, so
    # reusing one array across batches would mix raw and preprocessed pixels.
    batch_data = np.zeros((len(batch_names), 224, 224, 3))
    for slot, name in enumerate(batch_names):
      # load an image from file and convert its pixels to a numpy array
      image = load_img(os.path.join(directory, name), target_size=(224, 224))
      batch_data[slot] = img_to_array(image)

    # prepare the images for the VGG model
    batch_data = preprocess_input(batch_data)
    # get features
    batch_features = model.predict(batch_data, verbose=0)

    # store each feature row under its image id
    for slot, name in enumerate(batch_names):
      features[name.split('.')[0]] = batch_features[slot]

    print(f'Batch #{batch_index}/{total_batches} ready')
  return features

In [None]:
# Run the feature extractor over the image folder and keep the resulting
# {image id: feature vector} dict for later cells.
# NOTE(review): the folder is spelled 'Flicker8k' while the downloaded archives
# are 'Flickr8k' -- presumably this is the directory name inside the zip; confirm.
features_dict = extract_features('./Flicker8k_Dataset')

In [None]:
def load_doc(filename):
	"""Return the entire contents of the text file at `filename` as one string.

	Raises:
		OSError: if the file cannot be opened or read.
	"""
	# The context manager closes the handle even when read() raises.
	with open(filename, 'r') as file:
		return file.read()

# Caption file, resolved relative to the working directory
# (presumably unpacked from Flickr8k_text.zip -- confirm).
filename = 'Flickr8k.token.txt'
# Read the whole caption file into a single string for parsing.
doc = load_doc(filename)

In [None]:
def load_descriptions(doc):
	"""Group image descriptions by image id.

	Each non-trivial line of `doc` is split on whitespace: the first token is
	the image reference (everything from its first '.' onward is discarded to
	get the id) and the remaining tokens, re-joined with single spaces, form
	one description.

	Args:
		doc: Full text of the caption file, one caption per line.

	Returns:
		dict mapping image id to the list of its descriptions, in file order.
	"""
	mapping = dict()
	for line in doc.split('\n'):
		# split line by white space
		tokens = line.split()
		# A usable line needs an image id plus at least one description word.
		# Testing the token count (not the raw line length) also skips
		# whitespace-only lines, which would otherwise fail on tokens[0].
		if len(tokens) < 2:
			continue
		# take the first token as the image id, the rest as the description
		image_id, image_desc = tokens[0], tokens[1:]
		# remove filename extension (and anything after it) from the image id
		image_id = image_id.split('.')[0]
		# create the list on first sight of this id, then store the description
		mapping.setdefault(image_id, list()).append(' '.join(image_desc))
	return mapping

# Parse the raw caption text into {image id: [descriptions]} and report how
# many distinct image ids were found.
descriptions = load_descriptions(doc)
print('Loaded: %d ' % len(descriptions))

Loaded: 8092 


In [None]:
import string

def clean_descriptions(descriptions):
	"""Normalise every caption in `descriptions` in place.

	Each caption is split on whitespace, lower-cased, and stripped of
	punctuation characters; tokens that end up shorter than two characters
	or that are not purely alphabetic are dropped. The surviving tokens are
	re-joined with single spaces and written back into the same list slot.

	Args:
		descriptions: dict mapping a key to a list of caption strings; the
			lists are modified in place.

	Returns:
		None.
	"""
	# translation table that deletes every punctuation character
	strip_punctuation = str.maketrans('', '', string.punctuation)
	for captions in descriptions.values():
		for idx, caption in enumerate(captions):
			# tokenize, lower-case, then remove punctuation from each token
			words = (
				token.lower().translate(strip_punctuation)
				for token in caption.split()
			)
			# drop one-character leftovers (e.g. hanging 's' and 'a') and any
			# token containing non-letters, then store back as a string
			captions[idx] = ' '.join(
				word for word in words if len(word) > 1 and word.isalpha()
			)

# Normalise all captions; clean_descriptions rewrites the lists inside
# `descriptions` in place and returns nothing, so there is no new variable.
clean_descriptions(descriptions)

In [None]:
def to_vocabulary(descriptions):
	"""Return the set of distinct words used across all descriptions.

	Args:
		descriptions: dict mapping a key to a list of description strings.

	Returns:
		set of the whitespace-separated tokens found in any description.
	"""
	all_desc = set()
	# Plain loops instead of a list comprehension run only for its side
	# effects, which built and discarded a list of None values.
	for desc_list in descriptions.values():
		for desc in desc_list:
			all_desc.update(desc.split())
	return all_desc

# Collect the distinct words across all descriptions and report how many
# there are.
vocabulary = to_vocabulary(descriptions)
print('Vocabulary Size: %d' % len(vocabulary))

Vocabulary Size: 8763


In [None]:
def save_descriptions(descriptions, filename):
	"""Write all descriptions to `filename`, one "<key> <description>" per line.

	Lines are separated by '\\n' with no trailing newline after the last one.

	Args:
		descriptions: dict mapping a key to a list of description strings.
		filename: Path of the text file to create or overwrite.

	Raises:
		OSError: if the file cannot be opened or written.
	"""
	lines = [
		key + ' ' + desc
		for key, desc_list in descriptions.items()
		for desc in desc_list
	]
	# The context manager flushes and closes the handle even if write() raises.
	with open(filename, 'w') as file:
		file.write('\n'.join(lines))

In [None]:
# Persist the extracted image features. The features live in `features_dict`
# (assigned from extract_features above); the previous name `res` is not
# defined anywhere in this notebook and failed on a fresh kernel.
# The context manager makes sure the pickle file is flushed and closed.
with open('features.pickle', 'wb') as features_file:
	dump(features_dict, features_file)
# Persist the cleaned captions as plain text, one "<image id> <caption>" per line.
save_descriptions(descriptions, 'descriptions.txt')