preprocess.py (forked from weitat95/KnowingWhereToLook)
import os
import json
from collections import Counter
from random import seed, choice, sample

import h5py
import numpy as np
from tqdm import tqdm
from imageio import imread  # scipy.misc.imread was removed in SciPy 1.2+
from PIL import Image  # stands in for the likewise-removed scipy.misc.imresize

# Load Karpathy's split.
with open('../data/dataset_coco.json', 'r') as j:
    data = json.load(j)

# data_loc = '../data/Flicker8k_Dataset'  # change to disk scratch when training on GPU
data_loc = '../data/all_coco_img/'
# Define the lists that store the image paths and the corresponding captions
train_image_paths = []
train_image_captions = []
val_image_paths = []
val_image_captions = []
test_image_paths = []
test_image_captions = []
word_freq = Counter()
captions_per_image = 5
min_word_freq = 5
max_len = 50
# Let's visualize how the data is arranged
i = 0
for img in data['images']:
    if i == 1:
        break
    print(img['filename'])
    for c in img['sentences']:
        print(c['tokens'])
    i = 1
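# Each entry of data['images'] follows the Karpathy-split schema, roughly (the field
# values below are illustrative placeholders, not real data):
#   {'filename': '...jpg', 'split': 'train' | 'restval' | 'val' | 'test',
#    'sentences': [{'tokens': ['a', 'dog', ...], 'raw': 'A dog ...'}, ...]}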
# Create the lists. The image_paths lists contain the path of every image, and the
# image_captions lists contain the corresponding captions (possibly more than one per
# image), so the lengths of each pair must be identical.
for img in data['images']:  # Loop through every image in Karpathy's split
    captions = []  # Collect the captions of this specific image
    for c in img['sentences']:  # Loop through every sentence of the image
        # Update word frequency
        word_freq.update(c['tokens'])
        if len(c['tokens']) <= max_len:  # Skip the sentence if it is too long
            captions.append(c['tokens'])  # Append the sentence's token list; usually len(captions) == 5
    if len(captions) == 0:  # Guard: skip images whose every caption exceeds max_len,
        continue  # since choice() in the sampling step below would fail on an empty list
    # Get the path of the image (all images are assumed to sit flat inside data_loc)
    path = os.path.join(data_loc, img['filename'])
    if img['split'] in {'train', 'restval'}:  # Check which of Karpathy's splits this is
        train_image_paths.append(path)  # Append the image path to the image_paths list
        train_image_captions.append(captions)  # Append the captions list (usually of size 5)
    elif img['split'] in {'val'}:
        val_image_paths.append(path)
        val_image_captions.append(captions)
    elif img['split'] in {'test'}:
        test_image_paths.append(path)
        test_image_captions.append(captions)
#Check that everything is OK
assert len(train_image_paths) == len(train_image_captions)
assert len(val_image_paths) == len(val_image_captions)
assert len(test_image_paths) == len(test_image_captions)
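# For the standard Karpathy COCO split, these lists should hold roughly 113,287
# train ('train' + 'restval'), 5,000 val, and 5,000 test images.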
print("Words before filtering: ",len(word_freq.keys()))
words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
print("Words after filtering: ",len(words))
#Create the word2index dictionary (maps a word to a unique index)
word_map = {k: v + 1 for v, k in enumerate(words)}
word_map['<unk>'] = len(word_map) + 1
word_map['<start>'] = len(word_map) + 1
word_map['<end>'] = len(word_map) + 1
word_map['<pad>'] = 0
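# The resulting map reserves index 0 for <pad> and assigns contiguous indices to
# everything else, e.g. (illustrative): {'<pad>': 0, 'a': 1, ..., '<unk>': V + 1,
# '<start>': V + 2, '<end>': V + 3}, where V = len(words).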
# Save the word map to disk
with open('WORDMAP.json', 'w') as j:
    json.dump(word_map, j)

seed(123)  # Seed the RNG so caption sampling below is reproducible (value is arbitrary)
for impaths, imcaps, split in [(train_image_paths, train_image_captions, 'TRAIN'),
                               (val_image_paths, val_image_captions, 'VAL'),
                               (test_image_paths, test_image_captions, 'TEST')]:
    with h5py.File(os.path.join('caption data', split + '_IMAGES_' + '.hdf5'), 'a') as h:
        # Make a note of the number of captions we are sampling per image
        h.attrs['captions_per_image'] = captions_per_image
        # Create a dataset inside the HDF5 file to store the images, as uint8 in
        # (channels, height, width) order, the layout PyTorch expects
        images = h.create_dataset('images', (len(impaths), 3, 224, 224), dtype='uint8')
        print("\nReading {} images and captions, storing to file...\n".format(split))
        enc_captions = []
        caplens = []
        for i, path in enumerate(tqdm(impaths)):
            # Sample captions: if the image has fewer than 5 captions, pad the list by
            # repeating randomly chosen ones; otherwise draw exactly 5 without
            # replacement (which also shuffles their order) using random.sample
            if len(imcaps[i]) < captions_per_image:
                captions = imcaps[i] + [choice(imcaps[i]) for _ in range(captions_per_image - len(imcaps[i]))]
            else:
                captions = sample(imcaps[i], k=captions_per_image)
            # Check that everything is OK
            assert len(captions) == captions_per_image
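            # e.g. (illustrative): 3 available captions [c1, c2, c3] might become
            # [c1, c2, c3, c3, c1]; 5 or more are downsampled to exactly 5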
            # Read the image; replicate grayscale images across 3 channels
            img = imread(impaths[i])
            if len(img.shape) == 2:
                img = img[:, :, np.newaxis]
                img = np.concatenate([img, img, img], axis=2)
            # Resize to 224x224 (PIL stands in for the removed scipy.misc.imresize)
            img = np.array(Image.fromarray(img).resize((224, 224)))
            img = img.transpose(2, 0, 1)  # make channels first, as expected by PyTorch
            assert img.shape == (3, 224, 224)  # raise an error on any other shape
            # Save the image to the HDF5 file
            images[i] = img
            for c in captions:  # For each of the 5 sampled captions
                # Encode the caption: dict.get(key, default) returns word_map[word]
                # when the word is in the vocabulary and <unk>'s index otherwise
                enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in c] + \
                        [word_map['<end>']] + [word_map['<pad>']] * (max_len - len(c))
                # Caption length without the padding (but counting <start> and <end>)
                c_len = len(c) + 2
                enc_captions.append(enc_c)  # Append the encoded caption to enc_captions
                caplens.append(c_len)  # Append the caption length to caplens
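            # Worked example (illustrative): with max_len = 50, c = ['a', 'red', 'car']
            # encodes as [<start>, a, red, car, <end>] plus 47 <pad> indices (52 in
            # total), and c_len = 3 + 2 = 5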
        # Make sure everything is consistent; recall that images.shape[0] == len(impaths)
        assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens)
        # Save the encoded captions and their lengths to JSON files
        with open(os.path.join('caption data', split + '_CAPTIONS_' + '.json'), 'w') as j:
            json.dump(enc_captions, j)
        with open(os.path.join('caption data', split + '_CAPLENS_' + '.json'), 'w') as j:
            json.dump(caplens, j)
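
# A minimal read-back check (an added sketch, not part of the original pipeline):
# confirm that the stored images and captions line up for one split.
with h5py.File(os.path.join('caption data', 'TRAIN_IMAGES_.hdf5'), 'r') as h:
    with open(os.path.join('caption data', 'TRAIN_CAPTIONS_.json'), 'r') as j:
        train_caps = json.load(j)
    assert h['images'].shape[0] * h.attrs['captions_per_image'] == len(train_caps)
    print("Read-back OK: {} images, {} captions".format(h['images'].shape[0], len(train_caps)))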