In [None]:
import os
import logging

from PIL import Image
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import json

import nltk
nltk.download('wordnet')
nltk.download('punkt')
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize

In [None]:
# Fetch the project data by cloning our GitHub repository, then load the
# three COCO 2014 annotation JSON files into plain Python dicts.
#
# NOTE(review): the guard below tests for './COCO', but the cloned data is
# renamed to 'CTRL' and every JSON file in this cell is read from 'CTRL/...'.
# Unless a './COCO' directory is created elsewhere, the clone/mv commands are
# attempted again on every run of this cell -- confirm which directory the
# guard should check.
# NOTE(review): the `mv` source ('Conditional_Text_Generation_Project') does
# not match the repository name in the clone URL ('MLAI_LINKS_project') --
# confirm the directory that `git clone` actually creates.

if not os.path.isdir('./COCO'):
  !git clone https://github.com/MarcoSaponara/MLAI_LINKS_project.git
  !mv 'Conditional_Text_Generation_Project' 'CTRL'

# Training captions.
with open("CTRL/annotations_train_val_2014/captions_train2014.json","r") as f:
  train_dataset = json.load(f)

# Validation captions.
with open("CTRL/annotations_train_val_2014/captions_val2014.json","r") as f:
  val_dataset = json.load(f)

# Test image info (the description cell below also reads its 'categories' key).
with open("CTRL/annotations_test_2014/image_info_test2014.json","r") as f:
  test_dataset = json.load(f)

In [None]:
# Build a CocoDataset over the training captions file.
# NOTE(review): neither `CocoDataset` nor `data` is imported or defined in the
# cells visible here, so this cell fails on a fresh kernel unless they come
# from somewhere else -- confirm or remove. The path also uses 'COCO/...'
# whereas the JSON loading cell reads the same file from 'CTRL/...', and `ds`
# is overwritten later by the train_labels.json load.
ds = CocoDataset(path='COCO/annotations_train_val_2014/captions_train2014.json', text_field=data.Field())

In [None]:
# dataset description
# Prints the top-level structure of the loaded annotation dicts together with
# a short textual explanation of each part. Nothing is modified here.

# Top-level keys of the training set.
print(train_dataset.keys())
print(" info is a dict that gives info about the dataset \n licenses is a list of licenses related to the source of the images ")
print("\n")
# Keys of the first image record and the number of image records.
print("images is a list of dictionaries , each dict is a photo and contains basic info like the url, dimensions and id")
print(train_dataset['images'][0].keys(),len(train_dataset['images']), " elements")
print("\n")
# Keys of the first annotation (caption) record and the number of annotations.
print("annotations is a list of dictionaries, each dict is a caption")
print(train_dataset['annotations'][0].keys(),len(train_dataset['annotations']), " elements")
print("there are much more captions than images so each image has more than one caption")
print("\n")
# Top-level keys of the test set, its image count, and its category records.
print("test set is the same but instead of the annotations it provides us the categories of the images it contains")
print(test_dataset.keys(), len(test_dataset['images']))
print("categories is a list of dictionaries, each dict is one of the categories")
print(test_dataset['categories'][0].keys())
print("supercategory is a general category while name is a more specific one")
# The second category record, shown as a concrete example.
print("example: ",test_dataset['categories'][1])

In [None]:
# Number of training images and the ascending list of their ids.
N_images = len(train_dataset['images'])
lista_im = sorted(image['id'] for image in train_dataset['images'])

In [None]:
# Number of training captions, plus two parallel lists (same order as the
# annotations): the caption text and the id of the image it describes.
N_capts = len(train_dataset['annotations'])
lista_cap = [annotation['caption'] for annotation in train_dataset['annotations']]
lista_im_cap = [annotation['image_id'] for annotation in train_dataset['annotations']]

In [None]:
# One-column frame holding every image id.
immagini = pd.DataFrame({'image_id': lista_im})
# One row per caption: the id of the image it belongs to and the caption text.
captions = pd.DataFrame({'image_id': lista_im_cap, 'capt': lista_cap})
# Join captions to images on their shared 'image_id' column.
capt_image = captions.merge(immagini)
# Same rows, grouped by image id.
capt_image_group = capt_image.groupby('image_id')
capt_image.head(15)

In [None]:
# Distinct category names found in the test-set category records.
names = {category['name'] for category in test_dataset['categories']}
print(names)
# Distinct supercategories found in the same records.
supercategories = {category['supercategory'] for category in test_dataset['categories']}
print(supercategories)

In [None]:
# Load the ';'-separated object list and rename its columns to the
# 'id' / 'name' / 'supercategory' keys used by the rest of the notebook.
full_objects_list = (
  pd.read_csv("COCO/full_objects_list.txt", sep=';')
    .rename(columns={
      'ID': 'id',
      'Object (Paper)': 'name',
      'Super Category': 'supercategory',
    })
)

In [None]:
def check_caption(cap):
  """Return the ids of the categories of `full_objects_list` mentioned in a caption.

  The module-level `full_objects_list` frame is scanned twice:

  1. by supercategory -- a category matches when its supercategory occurs in
     `cap`; categories whose supercategory is 'person' match when any of a
     fixed list of person-like words occurs in `cap`;
  2. by name -- a category matches when its name occurs in `cap`.

  Args:
    cap: the caption, as any container supporting `in`. The notebook passes
      the token list produced by `lemmarize`, so matching is whole-token
      equality. NOTE(review): with a token list, a multi-word category name
      can never match; confirm whether `full_objects_list` contains any.

  Returns:
    A list of category ids without duplicates, supercategory matches first
    and name matches after, each group in the row order of
    `full_objects_list`. The list is empty when nothing matches.
  """
  person_synonyms = ('person', 'man', 'woman', 'couple', 'group', 'people', 'girl', 'boy')

  # Read the three needed columns once instead of building a Series per row
  # with iterrows() on both passes.
  categories = list(zip(full_objects_list['id'],
                        full_objects_list['name'],
                        full_objects_list['supercategory']))

  ids = list() # to store the ids of the categories related to the caption
  seen = set() # ids already stored, so a category is never reported twice

  def add(cat_id):
    # A category may match in both passes (e.g. 'person' is both a synonym
    # and a name); keep only its first occurrence.
    if cat_id not in seen:
      seen.add(cat_id)
      ids.append(cat_id)

  # Pass 1: look for supercategories.
  for cat_id, _, supercategory in categories:
    if supercategory == 'person':
      if any(syn in cap for syn in person_synonyms):
        add(cat_id)
    elif supercategory in cap:
      add(cat_id)

  # Pass 2: look for names in categories.
  for cat_id, name, _ in categories:
    if name in cap:
      add(cat_id)

  return ids

def lemmarize(text):
  """Tokenize `text`, lowercase each token and return the list of its WordNet lemmas."""
  lemmatizer = WordNetLemmatizer()
  return [lemmatizer.lemmatize(word.lower()) for word in word_tokenize(text)]

In [None]:
# Attach to every training annotation, under 'lab_ids', the category ids
# found in its lemmatized caption.
for annot in train_dataset['annotations']:
  tokens = lemmarize(annot['caption'])
  annot['lab_ids'] = check_caption(tokens)

In [None]:
# Count the annotations whose 'lab_ids' list is empty and report them as a
# percentage of all training annotations.
unlabeled = sum(1 for an in train_dataset['annotations'] if not an['lab_ids'])
percentage = 100*unlabeled/len(train_dataset['annotations'])

print('unlabeled captions: ' + str(percentage)+'%')

unlabeled captions: 18.2071077218054%


In [None]:
# Invert the labelling: map every category id of `full_objects_list` to the
# list of annotation ids whose caption was tagged with that category.
# Categories that were never matched keep an empty list.
catlabs = {cat_id: [] for cat_id in full_objects_list['id']}

for annot in train_dataset['annotations']:
  # `lab_id` rather than `id`: a loop variable named `id` would leak into the
  # notebook namespace and shadow the builtin for every later cell.
  for lab_id in annot['lab_ids']:
    catlabs[lab_id].append(annot['id'])

In [None]:
# Persist the category -> annotation-ids mapping as JSON in the working
# directory. JSON object keys are strings, so the category ids come back as
# str when the file is loaded (hence the later lookup with the key '1').
# NOTE(review): this writes './train_labels.json', while the dataset
# generation section reads 'CTRL/annotations_train_val_2014/train_labels.json'
# -- confirm the file is meant to be copied there (or already ships with the
# repository).
with open('./train_labels.json', 'w') as fp:
    json.dump(catlabs, fp)

### DATASET GENERATION

In [None]:
# Load the label file into `ds`; the cells below index it with the string
# key '1' and slice the resulting list of annotation ids.
# NOTE(review): `ds` was bound to a CocoDataset object in an earlier cell;
# this assignment replaces it.
with open('CTRL/annotations_train_val_2014/train_labels.json','r') as f:
  ds = json.load(f)

In [None]:
# Number of annotation ids stored under key '1', i.e. the size of the
# dataset with label 'person'.
len(ds['1'])

164688

In [None]:
# creating datasets for training
# Four files, CTRL/dataset10k_1.txt .. CTRL/dataset10k_4.txt, each holding the
# captions (one per line) of the annotations whose id falls in one consecutive
# block of 10000 ids of ds['1']. Captions are written in the order of
# train_dataset['annotations'].
for k in range(4):
  # Build the id block once per file as a set: testing membership against a
  # freshly sliced list for every annotation is needlessly quadratic.
  chunk_ids = set(ds['1'][k*10000:(k+1)*10000])
  chunk_captions = [x['caption'] for x in train_dataset['annotations'] if x['id'] in chunk_ids]
  # The file name previously contained a stray " +" inside the string literal.
  with open('CTRL/dataset10k_' + str(k+1) + '.txt','w') as f:
    for line in chunk_captions:
      # One caption per line, consistent with the sample dataset file.
      f.write(line + '\n')



In [None]:
# creating dataset with prompts
# For the annotations whose id is among the last 1000 ids of ds['1'], write
# one prompt per line: the control word 'Person' followed by the first two
# words of the caption.
prompt_ids = set(ds['1'][-1000:]) # set: constant-time membership test
prompt_captions = [x['caption'] for x in train_dataset['annotations'] if x['id'] in prompt_ids]
with open('CTRL/prompts.txt','w') as f:
  for caption in prompt_captions:
    # split() without argument ignores leading/repeated whitespace, so the
    # prompt never starts with an empty token.
    words = caption.split()
    if not words:
      continue # blank caption: nothing to build a prompt from
    # words[:2] also copes with one-word captions, which used to raise
    # IndexError on line[1].
    f.write('Person ' + ' '.join(words[:2]) + '\n')

In [None]:
# creating sample dataset for evaluation
# Captions (one per line) of the annotations whose id is in ds['1'][40000:41000],
# i.e. the 1000 ids right after the four training blocks. Captions are written
# in the order of train_dataset['annotations'].
sample_ids = set(ds['1'][40000:41000]) # set: avoids re-slicing and scanning a list per annotation
sample_captions = [x['caption'] for x in train_dataset['annotations'] if x['id'] in sample_ids]
with open('CTRL/sample_dataset.txt','w') as f:
  for line in sample_captions:
    f.write(line+'\n')