# Preprocessing

In [None]:
import os
import numpy as np
import h5py
import json
import torch

from PIL import Image

from collections import Counter
from random import seed, choice, sample

from os import listdir
from tqdm import tqdm
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize
from pickle import dump


import shutil

In the cell below one chooses the arguments for preprocessing the data before training the model. The `dataset` is a folder containing `captions` and `img` directories, which hold the images in `.jpg` format and the captions for every image in `.txt` format. The images and captions are linked through `image_id`.

In [None]:
# --- Caption / vocabulary settings ------------------------------------------
# Captions are cut / padded to this many tokens.
max_cap_lenght = 60
# Number of captions stored for every image.
cap_per_img = 5
# A word is kept in the vocabulary only if it occurs MORE often than this.
min_word_freq = 10

# --- Input dataset and output folder (CHECK THESE BEFORE EVERY RERUN) --------
# Template:
# dataset = 'dataset_to_train_on'
# folder = 'where_output_will_go'
dataset = 'celeb_small_res'
folder = 'trial'

# --- Caption source switches: set at most one of these to True ---------------
full_anton_captions = False
partial_anton_captions = False
mix_data = True
mix_captions = False

# --- Derived names and locations ---------------------------------------------
main_path = '/Users/evelsve/repos/cap'
out_path = f'{main_path}/preprocessing_outputs/{folder}'
# Common suffix of every generated file: <dataset>_<cap_per_img>_<min_word_freq>
base_filename = '_'.join(str(part) for part in (dataset, cap_per_img, min_word_freq))
combine_with_sketches = False

The cell below defines functions to rename files, in this case the *Distorted* images (in-house we called them freakshow).

In [None]:
def strip_flickr(f_name):
    """Return the image id encoded in a flickr8k file name.

    The id is the first whitespace-delimited token of the name once a
    trailing ``.jpg`` extension has been removed.

    Only the literal ``.jpg`` suffix is dropped. ``str.strip('.jpg')`` would
    instead remove any of the characters ``.``, ``j``, ``p``, ``g`` from BOTH
    ends of the name and corrupt ids that start or end with one of them.
    """
    suffix = '.jpg'
    stem = f_name[:-len(suffix)] if f_name.endswith(suffix) else f_name
    file_id = stem.split()[0]
    return file_id


def strip_freakshow(f_name):
    """Return the image id encoded in a *Distorted* ("freakshow") file name.

    The name is expected to contain at least three ``_``-separated fields;
    the id is the third field, taken after a trailing ``.jpg`` extension has
    been removed. Names with fewer fields raise ``IndexError``.

    Only the literal ``.jpg`` suffix is dropped. ``str.strip('.jpg')`` would
    instead remove any of the characters ``.``, ``j``, ``p``, ``g`` from BOTH
    ends of the name.
    """
    suffix = '.jpg'
    stem = f_name[:-len(suffix)] if f_name.endswith(suffix) else f_name
    file_id = stem.split('_')[2]
    return file_id


def do_files_rename(path_to_data, dataset_from, folder_to):
    """Copy the images of one dataset into another folder under their bare ids.

    Every file in ``<path_to_data>/<dataset_from>/img`` is copied to
    ``<path_to_data>/<folder_to>/img/<file_id>.jpg``. The id is extracted
    with ``strip_flickr`` when ``dataset_from`` is ``'flickr8k'`` and with
    ``strip_freakshow`` for any other dataset.

    Args:
        path_to_data: directory that contains both dataset folders.
        dataset_from: name of the source dataset folder.
        folder_to: name of the destination dataset folder.
    """
    # The files are copied from the ``img`` sub-folder, so that is the folder
    # that has to be listed (listing the dataset root yields names that do
    # not exist under ``img``).
    source_dir = f'{path_to_data}/{dataset_from}/img'
    target_dir = f'{path_to_data}/{folder_to}/img'
    # copyfile does not create missing parent directories.
    os.makedirs(target_dir, exist_ok=True)

    for f_name in listdir(source_dir):
        if dataset_from == 'flickr8k':
            file_id = strip_flickr(f_name)
        else:
            file_id = strip_freakshow(f_name)
        original = f'{source_dir}/{f_name}'
        target = f'{target_dir}/{file_id}.jpg'
        shutil.copyfile(original, target)
        
def rename_and_create_flickr(main_path, path_to_old, path_to_new):
    """Split the flickr8k ``captions.txt`` into one caption file per image.

    ``<path_to_old>/captions.txt`` is read (its first line is skipped as a
    header); every remaining line is split on ``'.jpg,'`` into an image id
    and a caption. Images are numbered in order of first appearance and the
    captions of image number ``i`` are written, one per line, to
    ``<path_to_new>/<i>.txt``. Finally the flickr8k images are copied with
    ``do_files_rename``.

    Args:
        main_path: project root; images are copied inside ``<main_path>/data``.
        path_to_old: folder holding the original ``captions.txt``.
        path_to_new: folder that receives the per-image caption files.

    NOTE(review): caption files are named after the new running index, while
    ``do_files_rename`` names the copied images after the original flickr id,
    so captions and images do not share a name after this step -- confirm
    whether an additional renaming step is expected.
    """
    flick = dict()

    with open(f'{path_to_old}/captions.txt', mode='r') as f:
        lines = f.readlines()

    for line in lines[1:]:
        line = line.strip()
        if not line:
            # tolerate blank (e.g. trailing) lines instead of crashing
            continue
        parts = line.split('.jpg,')
        image_id, caption = parts[0], parts[1]
        # setdefault keeps the FIRST caption of every image as well; creating
        # the list without appending would silently drop it.
        flick.setdefault(image_id, list()).append(caption)

    os.makedirs(path_to_new, exist_ok=True)
    for i, key in enumerate(flick):
        with open(f'{path_to_new}/{i}.txt', mode='w') as f:
            for item in flick[key]:
                print(item, file=f)

    do_files_rename(f'{main_path}/data', 'flickr8k', 'flickr')
    

The cell below defines:
- reading from dataset into a dictionary
- creating word maps for encoding
- creating `.hdf5` files, which are later used for training


There is some leftover code (commented) for joining images and sketches into one embedding, yet the model was not adapted to deal with 6 channels instead of 3 -- future work.

The logic for everything below is similar to that in `a-PyTorch-Tutorial-to-Image-Captioning`, yet we adapted it in such a way that there is no need for Karpathy files or similar, just plain dictionaries of paths and captions.

In [None]:
def load_doc(filename):
    """Read the file ``filename`` in text mode and return its whole content.

    The file is opened in a ``with`` block so the handle is closed even when
    reading fails.
    """
    with open(filename, 'r') as file:
        text = file.read()
    return text


def read_create_dataset(directory, augment=False):
    """Read images and captions of the configured dataset into dictionaries.

    The dataset is located at ``<main_path>/data/<dataset>`` (module-level
    settings). Every ``<id>.jpg`` in its ``img`` folder is paired with the
    caption file ``<id>.txt``; at most the first 10 lines of that file are
    used. Captions are tokenized with ``word_tokenize``, lower-cased, and
    commas are dropped.

    Args:
        directory: not used by this function; kept so existing calls keep
            working.
        augment: when True, captions are read from
            ``<main_path>/data/augmented_captions`` instead of the dataset's
            own ``captions`` folder.

    Returns:
        ``(paths_caption_dictn, id_path_dictn)``: image path -> list of
        token lists, and integer image id -> image path.
    """
    paths_caption_dictn, id_path_dictn = dict(), dict()
    path_to_data = f'{main_path}/data/{dataset}'
    # This name was referenced but never defined; images live in ``img``.
    path_to_images = f'{path_to_data}/img'
    for f_name in tqdm(listdir(path_to_images)):
        img_filepath = f'{path_to_images}/{f_name}'
        jpgname = f_name.split('.')
        # Names without any extension have a single part; report them like
        # every other unexpected file instead of raising IndexError.
        if len(jpgname) < 2 or jpgname[1] != 'jpg':
            print(f'Encountered random file: {f_name}')
            continue

        image_id = int(jpgname[0])
        if augment:
            cap_filepath = f'{main_path}/data/augmented_captions/{image_id}.txt'
        else:
            cap_filepath = f'{path_to_data}/captions/{image_id}.txt'
        id_path_dictn[image_id] = img_filepath
        lines = load_doc(cap_filepath)
        lines = lines.split('\n')
        paths_caption_dictn[img_filepath] = list()
        for caption in lines[:10]:
            if not caption.strip():
                # a trailing newline would otherwise yield an empty caption
                continue
            caption = word_tokenize(caption)
            caption = [token.lower() for token in caption if token != ',']
            paths_caption_dictn[img_filepath].append(caption)
    return paths_caption_dictn, id_path_dictn



def mixing_read_create_dataset(main_path):
    """Read the configured dataset, swapping part of it for flickr data.

    Every ``<id>.jpg`` found in ``<main_path>/data/<dataset>/img`` is paired
    with its caption file. When the module-level ``mix_data`` flag is set and
    ``7000 <= id < 8090``, both the image path and the caption file are taken
    from ``<main_path>/data/flickr`` instead. At most the first 10 lines of a
    caption file are considered, and lines of 3 characters or fewer are
    skipped. Captions are tokenized with ``word_tokenize``, lower-cased, and
    commas are dropped.

    Args:
        main_path: project root that contains the ``data`` folder.

    Returns:
        ``(paths_caption_dictn, id_path_dictn)``: image path -> list of
        token lists, and integer image id -> image path.
    """
    paths_caption_dictn, id_path_dictn = dict(), dict()
    path_to_data = f'{main_path}/data/{dataset}'
    path_to_flickr = f'{main_path}/data/flickr'
    for f_name in tqdm(listdir(f'{path_to_data}/img')):
        img_filepath = f'{path_to_data}/img/{f_name}'
        jpgname = f_name.split('.')
        # Names without any extension have a single part; report them like
        # every other unexpected file instead of raising IndexError.
        if len(jpgname) < 2 or jpgname[1] != 'jpg':
            print(f'Encountered random file: {f_name}')
            continue

        image_id = int(jpgname[0])
        if mix_data and 7000 <= image_id < 8090:
            # ids in this range are served from the flickr folder
            img_filepath = f'{path_to_flickr}/img/{image_id}.jpg'
            cap_filepath = f'{path_to_flickr}/captions/{image_id}.txt'
        else:
            cap_filepath = f'{path_to_data}/captions/{image_id}.txt'
        id_path_dictn[image_id] = img_filepath

        # load captions
        lines = load_doc(cap_filepath)
        lines = lines.split('\n')
        # add path + list to dictn
        paths_caption_dictn[img_filepath] = list()
        # take only 10 descriptions
        for caption in lines[:10]:
            if len(caption) > 3:
                # add a tokenized caption to path_captions dict
                caption = word_tokenize(caption)
                caption = [token.lower() for token in caption if token != ',']
                paths_caption_dictn[img_filepath].append(caption)

    return paths_caption_dictn, id_path_dictn



def partial_augment_read_create_dataset(directory=None):
    """Read the configured dataset, replacing some captions by augmented ones.

    Every ``<id>.jpg`` in ``<root>/data/<dataset>/img`` is paired with its
    caption file ``<id>.txt``. At most the first 10 caption lines are used;
    the lines at index 4, 5 and 6 are replaced by the line with the same
    index from ``<root>/data/augmented_captions/<id>.txt``. Captions are
    tokenized with ``word_tokenize``, lower-cased, and commas are dropped.

    Args:
        directory: optional project root. When omitted, the module-level
            ``main_path`` is used. The parameter is accepted so the function
            can be called like the other readers, with the root as argument.

    Returns:
        ``(paths_caption_dictn, id_path_dictn)``: image path -> list of
        token lists, and integer image id -> image path.

    Raises:
        IndexError: if the augmented caption file has no line at an index
            that must be replaced.
    """
    root = main_path if directory is None else directory
    paths_caption_dictn, id_path_dictn = dict(), dict()
    path_to_data = f'{root}/data/{dataset}'
    # This name was referenced but never defined; images live in ``img``.
    path_to_images = f'{path_to_data}/img'
    for f_name in tqdm(listdir(path_to_images)):
        img_filepath = f'{path_to_images}/{f_name}'
        jpgname = f_name.split('.')
        # Names without any extension have a single part; report them like
        # every other unexpected file instead of raising IndexError.
        if len(jpgname) < 2 or jpgname[1] != 'jpg':
            print(f'Encountered random file: {f_name}')
            continue

        image_id = int(jpgname[0])

        alt_cap_filepath = f'{root}/data/augmented_captions/{image_id}.txt'
        cap_filepath = f'{path_to_data}/captions/{image_id}.txt'

        id_path_dictn[image_id] = img_filepath

        lines = load_doc(cap_filepath)
        lines = lines.split('\n')

        alt_lines = load_doc(alt_cap_filepath)
        alt_lines = alt_lines.split('\n')

        paths_caption_dictn[img_filepath] = list()

        for i, caption in enumerate(lines[:10]):
            if 3 < i < 7:
                caption = alt_lines[i]
            if not caption.strip():
                # a trailing newline would otherwise yield an empty caption
                continue
            caption = word_tokenize(caption)
            caption = [token.lower() for token in caption if token != ',']
            paths_caption_dictn[img_filepath].append(caption)

    return paths_caption_dictn, id_path_dictn



def dictn_reverser(dictn):
    """Return a new dict mapping every value of ``dictn`` back to its key.

    When several keys share a value, the key that comes last in iteration
    order wins.
    """
    inverted = dict()
    for key, value in dictn.items():
        inverted[value] = key
    return inverted

def get_dict_from_file(path):
    """Parse the JSON file at ``path`` and return the decoded object."""
    with open(path, "r") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def create_word_map(data, max_cap_lenght):
    """Build the word -> index vocabulary from all tokenized captions.

    Args:
        data: mapping of image path -> list of captions, each caption being
            a list of tokens.
        max_cap_lenght: only the first ``max_cap_lenght`` tokens of a caption
            are taken into account.

    Returns:
        ``(word_map, captions)``. ``captions`` is the flat list of every
        counted token. ``word_map`` numbers, starting at 1 and in order of
        first appearance, every token that occurs more than the module-level
        ``min_word_freq`` times, followed by ``<unk>``, ``<start>`` and
        ``<end>``; ``<pad>`` is index 0.
    """
    # Flat list of all tokens that survive the length cut
    captions = list()
    for path in data:
        for captn in data[path]:
            captions.extend(captn[:max_cap_lenght])

    # Counter keeps keys in first-seen order, which fixes the index order
    word_freq = Counter(captions)
    frequent = [token for token, count in word_freq.items() if count > min_word_freq]

    word_map = dict()
    for position, token in enumerate(frequent, start=1):
        word_map[token] = position
    # Special tokens are appended one after another, so each index depends
    # on the current size of the map.
    for special in ('<unk>', '<start>', '<end>'):
        word_map[special] = len(word_map) + 1
    word_map['<pad>'] = 0

    return word_map, captions


def create_h5py(directory, path_data, word_map, cap_per_img, max_cap_lenght, sketches_dictn, combine_with_sketches):
    """Write the image and caption files of every subset.

    For each subset name in ``path_data`` this creates

    * ``<directory>/<subset>_<base_filename>.hdf5`` with an ``images``
      dataset of shape ``(n_images, channels, 256, 256)`` (uint8) and the
      attribute ``captions_per_image``;
    * ``<out_path>/<subset>_CAPTIONS_<base_filename>.json`` with the encoded
      captions;
    * ``<out_path>/<subset>_CAPLENS_<base_filename>.json`` with the token
      count of every caption.

    ``base_filename`` and ``out_path`` are module-level settings.

    Args:
        directory: folder that receives the ``.hdf5`` files.
        path_data: mapping subset name -> {image path -> list of tokenized
            captions}.
        word_map: mapping token -> integer index; must contain ``<start>``,
            ``<end>``, ``<unk>`` and ``<pad>``.
        cap_per_img: number of captions stored per image. Images with fewer
            captions get randomly repeated ones, images with more get a
            random sample.
        max_cap_lenght: captions are cut to this many tokens and padded with
            ``<pad>`` up to it, so every encoded caption has
            ``max_cap_lenght + 1`` entries (the extra one is ``<start>``).
        sketches_dictn: mapping image id -> sketch path; only referenced by
            the disabled sketch code below.
        combine_with_sketches: when True the ``images`` dataset is allocated
            with 6 channels instead of 3. The code that would fill the extra
            channels is disabled, so this mode is not functional yet.

    Raises:
        ValueError: if an image has no caption at all.
    """
    channels = 6 if combine_with_sketches else 3
    for subset in path_data:
        subset_items = path_data[subset]
        with h5py.File(os.path.join(directory + '/' + subset + '_' + base_filename + '.hdf5'), 'w') as h:
            h.attrs['captions_per_image'] = cap_per_img
            images = h.create_dataset('images', (len(subset_items), channels, 256, 256), dtype='uint8')

            enc_captions = []
            caplens = []
            for i, path in enumerate(tqdm(subset_items)):
                # Numeric id of the image; only needed by the disabled sketch
                # lookup below. splitext removes just the extension, unlike
                # str.strip('.jpg'), which strips a set of characters.
                img_id = int(os.path.splitext(path.split('/')[-1])[0])

                imcaps = subset_items[path]
                if not imcaps:
                    # random.choice would fail with an unhelpful IndexError
                    raise ValueError(f'No captions available for image: {path}')

                if len(imcaps) < cap_per_img:
                    captions = imcaps + [choice(imcaps) for _ in range(cap_per_img - len(imcaps))]
                elif len(imcaps) == cap_per_img:
                    captions = imcaps
                else:
                    captions = sample(imcaps, k=cap_per_img)

                assert len(captions) == cap_per_img

                # Cut over-long captions (same limit as the word map uses);
                # without this they would be stored unpadded and longer than
                # all other rows.
                captions = [c[:max_cap_lenght] for c in captions]

                # Read images. Converting to RGB guarantees three channels
                # (greyscale images have no channel axis, RGBA has four).
                with Image.open(path) as raw_img:
                    img = raw_img.convert('RGB').resize((256, 256))
                # height x width x channel -> channel x height x width
                img = np.transpose(np.asarray(img), (2, 0, 1))

                # Below is our attempt of joining the sketches and images.
                # The code is fully functional yet the model is still not adapted

#                 if combine_with_sketches:
#                     sk_path = sketches_dictn[img_id]

#                     sk = Image.open(sk_path)
#                     sk = sk.resize((256, 256))
#                     sk = np.transpose(sk, (2, 0, 1))

#                     img = torch.from_numpy(img)
#                     img = torch.unsqueeze(img, 0)

#                     sk = torch.from_numpy(sk)
#                     sk = torch.unsqueeze(sk, 0)

#                     img = torch.cat((img, sk), 1)
#                     img = torch.squeeze(img, 0)
#                     img = img.cpu().detach().numpy()
#                     img = np.transpose(img, (0, 1, 2))

                assert np.max(img) <= 255

                # Save image to HDF5 file
                images[i] = img

                # encoding the captions
                for c in captions:
                    enc_c = [word_map['<start>']]
                    for word in c:
                        if word == '.':
                            # a full stop is encoded as the end marker
                            enc_c.append(word_map['<end>'])
                        else:
                            enc_c.append(word_map.get(word, word_map['<unk>']))

                    # pad up to the common length
                    enc_c.extend([word_map['<pad>']] * (max_cap_lenght - len(c)))

                    enc_captions.append(enc_c)
                    # stored length is the token count of the caption itself
                    # (the <start> marker is not counted)
                    caplens.append(len(c))

            # Sanity check
            assert images.shape[0] * cap_per_img == len(enc_captions) == len(caplens)

            # Save encoded captions and their lengths to JSON files
            with open(os.path.join(out_path, subset + '_CAPTIONS_' + base_filename + '.json'), 'w') as j:
                json.dump(enc_captions, j)
            with open(os.path.join(out_path, subset + '_CAPLENS_' + base_filename + '.json'), 'w') as j:
                json.dump(caplens, j)

If you are planning to use either `flickr8k` or *Distorted* data, use the cells below to rename and create files in the desired format for the subsequent code.

In [None]:
# Copy the *Distorted* images from data/freakshow_10k into data/freakshow/img,
# renaming each file to "<id>.jpg" (id extracted by strip_freakshow).
do_files_rename(f'{main_path}/data', 'freakshow_10k', 'freakshow')

In [None]:
# Only needed when part of the dataset is replaced by flickr data:
# split the flickr8k captions.txt into per-image caption files and copy
# the flickr8k images.
if mix_data:
    # Derived from main_path so the cell follows the configured project root
    # instead of a second, hard-coded copy of it.
    path_to_old = f'{main_path}/data/flickr8k'
    path_to_new = f'{main_path}/data/flickr'

    rename_and_create_flickr(main_path, path_to_old, path_to_new)

In [None]:
# Step 1 -- read the dataset with the reader that matches the configured
# caption mode; mix_data takes precedence over the two "anton" flags.
if mix_data:
    data, data_id = mixing_read_create_dataset(f'{main_path}')
elif partial_anton_captions:
    # Called without arguments: the reader takes the project root and the
    # dataset name from the module-level settings.
    data, data_id = partial_augment_read_create_dataset()
elif full_anton_captions:
    data, data_id = read_create_dataset(f'{main_path}', augment=True)
else:
    data, data_id = read_create_dataset(f'{main_path}')

path_data = dict()
path_data['TRAIN'], path_data['TEST'], path_data['VAL'] = dict(), dict(), dict()

# Split by numeric image id: <8000 train, 8000-8999 validation,
# 9000-9999 test. Ids of 10000 and above are not assigned to any subset.
for path, captions in data.items():
    # take path, get img name, remove the extension, convert to int
    # (splitext removes only the extension; str.strip('.jpg') strips a set
    # of characters from both ends)
    img_id = int(os.path.splitext(path.split('/')[-1])[0])
    if img_id < 8000:
        path_data['TRAIN'][path] = captions
    elif 8000 <= img_id < 9000:
        path_data['VAL'][path] = captions
    elif 9000 <= img_id < 10000:
        path_data['TEST'][path] = captions

print('Step 1: Read data done.')

# Step 2 -- vocabulary
word_map, all_captions = create_word_map(data, max_cap_lenght)

# The output folder is new for every `folder` setting; create it so the
# writes below (and those inside create_h5py) do not fail.
os.makedirs(out_path, exist_ok=True)

with open(os.path.join(out_path, 'wordmap_' + base_filename + '.json'), 'w') as j:
    json.dump(word_map, j)


print('Step 2: Wordmap creation done.')

sketches_dictn = dict()

if combine_with_sketches:
    sketches_path = "/Users/evelsve/repos/cap/data/sketches"
    for f_name in listdir(sketches_path):
        img_id = int(os.path.splitext(f_name.split('/')[-1])[0])
        path = sketches_path + '/' + f_name
        sketches_dictn[img_id] = path
    # Report once, after all sketch paths were collected (not per file)
    print('Step 1: Sketches path read done.')
    print(f'INFO: Lenght of sketches: {len(sketches_dictn)}')


print(f"INFO: In train: {len(path_data['TRAIN'])}\n In test: {len(path_data['TEST'])} \n In val: {len(path_data['VAL'])}")

print(f'INFO: Lenght of the vocabulary: {len(word_map)}')

# Step 3 -- image / caption files.
# NOTE(review): the .hdf5 files are written to main_path (first argument),
# while create_h5py writes the caption JSON files to out_path -- confirm
# that the two locations are meant to differ.
create_h5py(main_path, path_data, word_map, cap_per_img, max_cap_lenght, sketches_dictn, combine_with_sketches)

print('Step 3: Files for training, testing and validation created. END')

If one wants to analyze the output or to make sure everything went smoothly, the code below can be used.

In [None]:
import h5py
# Set `filename` to the .hdf5 file to inspect before running this cell, e.g.:
# filename = "data/outputs/specific_output_folder/file_to_check.hdf5"

# NOTE(review): `filename` is not assigned anywhere else in this notebook, so
# this cell raises NameError until the line above is uncommented and edited.
with h5py.File(filename, "r") as f:
    # List all groups
    print("Keys: %s" % f.keys())
    # Use the first key of the file
    a_group_key = list(f.keys())[0]

    # Get the data
    # (this rebinds the notebook-level name `data` that the preprocessing
    # cell also uses)
    data = list(f[a_group_key])

In [None]:
# Display the first sub-array of entry 200 of the data loaded in the cell above
# (presumably the first channel of image 200 when the key is 'images' -- TODO confirm).
data[200][0]

In [None]:
# Load one of the generated JSON files for inspection.
# NOTE(review): get_dict_from_file requires a `path` argument; as written this
# call raises TypeError -- pass the path of the JSON file to check.
data_to_check = get_dict_from_file()