# Image captioning using attention

Based on the TensorFlow tutorial [Image captioning with visual attention](https://www.tensorflow.org/text/tutorials/image_captioning)

Similar to the architecture described in [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044)

Features are extracted from the image and passed to the cross-attention layers of the transformer decoder.

Modifications to the tutorial are described along the way. This notebook does not currently leverage any available GPU, but it will by the time we're done.

## Step 0 - Environment Setup 

The packages listed in the provided conda YAML are sufficient for this work.
The YAML, which defines an environment compatible with all of these notebooks, is included in the root of this repo.

In [2]:
import concurrent.futures
import collections
import dataclasses
import hashlib
import itertools
import json
import math
import os
import pathlib
import random
import re
import string
import time
import urllib.request

import einops
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import requests
import tqdm

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
import tensorflow_datasets as tfds

## Step 1 - Data Handling

Download and prepare the dataset for training.
This step tokenizes the input text and caches the results of running the images through a pretrained feature extractor.

In [3]:
# Root folder under which every artefact of this notebook is stored.
local_data_path_root = "C:/LocalResearch/JPD-Research/translationWork"

# Locations derived from the root. The vocab and model paths deliberately
# end with a slash; the data path does not.
local_data_path = f"{local_data_path_root}/data"
local_vocab_path = f"{local_data_path_root}/vocab/"
local_model_path = f"{local_data_path_root}/models/"

# Directory where the checkpoints will be saved
checkpoint_dir = f"{local_data_path_root}/training_checkpoints"

In [21]:
def flickr8k(path='flickr8k'):
  """Download (when needed) and load the Flickr8k image-captioning dataset.

  The archives are fetched with `tf.keras.utils.get_file` into
  `local_data_path/path`; the caption file and the train/test image lists
  are then read from that same directory.

  Args:
    path: Sub-directory of the module-level `local_data_path` that holds
      (or will hold) the dataset.

  Returns:
    A `(train_ds, test_ds)` tuple of datasets built with
    `tf.data.experimental.from_list`. Each element is a pair
    `(image_path, captions)` where `image_path` is the string path of an
    image and `captions` is the list of captions recorded for that image.
  """
  path = pathlib.Path(path)
  print(path)
  print(local_data_path)
  fullpath = local_data_path/path

  # Count the files where get_file is told to put them (cache_dir/cache_subdir).
  # Checking the bare relative `path` would look in the current working
  # directory instead, so the download branch would be taken even when the
  # data is already present.
  # NOTE(review): 16197 is presumably the entry count of a complete
  # extraction - confirm.
  if len(list(fullpath.rglob('*'))) < 16197:
    for archive in ('Flickr8k_Dataset.zip', 'Flickr8k_text.zip'):
      tf.keras.utils.get_file(
          origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/' + archive,
          cache_dir=local_data_path,
          cache_subdir=path,
          extract=True)

  print(fullpath/"Flickr8k.token.txt")

  # Each line of the token file is "<image file>#<caption index>\t<caption>";
  # group the captions by image file name.
  cap_dict = collections.defaultdict(list)
  for line in (fullpath/"Flickr8k.token.txt").read_text().splitlines():
    fname, caption = line.split('\t')
    cap_dict[fname.split('#')[0]].append(caption)

  def load_split(list_file):
    # Pair every image named in `list_file` with its captions.
    fnames = (fullpath/list_file).read_text().splitlines()
    return [(str(fullpath/'Flicker8k_Dataset'/fname), cap_dict[fname]) for fname in fnames]

  train_ds = tf.data.experimental.from_list(load_split('Flickr_8k.trainImages.txt'))
  test_ds = tf.data.experimental.from_list(load_split('Flickr_8k.testImages.txt'))

  return train_ds, test_ds

In [36]:
def conceptual_captions(*, data_dir="conceptual_captions", num_train, num_val):
  """Download and load a subset of the Conceptual Captions dataset.

  The train and validation index files (TSV, one `caption<TAB>url` pair per
  line) are fetched with `tf.keras.utils.get_file`. The first `num_train` /
  `num_val` images they reference are then downloaded into
  `local_data_path/data_dir/train` and `.../val`.

  Args:
    data_dir: Sub-directory of the module-level `local_data_path` used for
      the index files and the downloaded images.
    num_train: Number of index entries to take from the training index.
    num_val: Number of index entries to take from the validation index.

  Returns:
    A `(train_raw, test_raw)` tuple of `tf.data.Dataset`s whose elements are
    `(image_path, captions)`, where `captions` holds the single caption of
    the image with a leading axis added. Entries whose image could not be
    downloaded are omitted.
  """
  def iter_index(index_path):
    # Yield (caption, url) pairs from a TSV index file.
    # Decode explicitly as UTF-8 so that the result does not depend on the
    # platform's default text encoding.
    with open(index_path, encoding='utf-8') as f:
      for line in f:
        caption, url = line.strip().split('\t')
        yield caption, url

  def download_image_urls(data_dir, urls):
    # Download every url into data_dir; return one path (or None) per url.
    def save_image(url):
      # Name the files after the hash of the URL.
      digest = hashlib.sha1(url.encode()).hexdigest()
      file_path = data_dir/f'{digest}.jpeg'
      if file_path.exists():
        # Only download each file once.
        return file_path

      try:
        response = requests.get(url, timeout=5)
        # An HTTP error body is not an image: treat it as a failed download
        # instead of saving it under a .jpeg name.
        response.raise_for_status()
      except Exception:
        # Deliberately best-effort: any failure just drops this image.
        return None

      file_path.write_bytes(response.content)
      return file_path

    # The context manager shuts the pool down once all downloads finished.
    # ex.map yields results in the order of `urls`.
    with concurrent.futures.ThreadPoolExecutor(max_workers=100) as ex:
      return list(tqdm.tqdm(ex.map(save_image, urls), total=len(urls)))

  def ds_from_index_file(index_path, data_dir, count):
    # Build a dataset from the first `count` entries of an index file.
    data_dir.mkdir(parents=True, exist_ok=True)
    index = list(itertools.islice(iter_index(index_path), count))
    captions = [caption for caption, url in index]
    urls = [url for caption, url in index]

    paths = download_image_urls(data_dir, urls)

    # Keep only the pairs whose download succeeded (path is not None).
    kept = [(str(path), cap) for cap, path in zip(captions, paths) if path is not None]
    new_paths = [path for path, cap in kept]
    new_captions = [cap for path, cap in kept]

    ds = tf.data.Dataset.from_tensor_slices((new_paths, new_captions))
    ds = ds.map(lambda path,cap: (path, cap[tf.newaxis])) # 1 caption per image
    return ds

  train_index_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/gcc-data/Train/GCC-training.tsv',
    cache_dir=local_data_path,
    cache_subdir=data_dir
    )

  val_index_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/gcc-data/Validation/GCC-1.1.0-Validation.tsv',
    cache_dir=local_data_path,
    cache_subdir=data_dir
    )

  # Images are stored next to the index files, in train/ and val/.
  image_root = pathlib.Path(local_data_path)/data_dir
  train_raw = ds_from_index_file(train_index_path, data_dir=image_root/'train', count=num_train)
  test_raw = ds_from_index_file(val_index_path, data_dir=image_root/'val', count=num_val)

  return train_raw, test_raw

In [43]:
# Dataset selector: 'flickr8k' loads Flickr8k; any other value loads a
# 10000/5000 (train/validation) subset of Conceptual Captions.
choose = 'flickr8k'

train_raw, test_raw = (
    flickr8k()
    if choose == 'flickr8k'
    else conceptual_captions(num_train=10000, num_val=5000)
)

flickr8k
C:/LocalResearch/JPD-Research/translationWork/data
C:\LocalResearch\JPD-Research\translationWork\data\flickr8k\Flickr8k.token.txt


In [44]:
# Show the structure (shape and dtype) of one training-dataset element.
train_raw.element_spec

(TensorSpec(shape=(), dtype=tf.string, name=None),
 TensorSpec(shape=(5,), dtype=tf.string, name=None))

In [45]:
# Peek at a single training example: its image path, then its captions.
for ex_path, ex_captions in train_raw.take(1):
  print(ex_path, ex_captions, sep='\n')

tf.Tensor(b'C:\\LocalResearch\\JPD-Research\\translationWork\\data\\flickr8k\\Flicker8k_Dataset\\2513260012_03d33305cf.jpg', shape=(), dtype=string)
tf.Tensor(
[b'A black dog is running after a white dog in the snow .'
 b'Black dog chasing brown dog through snow'
 b'Two dogs chase each other across the snowy ground .'
 b'Two dogs play together in the snow .'
 b'Two dogs running through a low lying body of water .'], shape=(5,), dtype=string)
