In [1]:
import argparse
import torch
from models.setup import *
from models.GeneralModels import *
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import scipy
import scipy.signal
import librosa
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np

In [2]:
def modelSetup(parser, test=False):
    """Load the JSON config named by ``parser["config_file"]`` and apply overrides.

    Args:
        parser: Mapping of option names to values. It must contain
            ``"config_file"`` (a key of ``config_library``) and
            ``"image_base"``. Every other entry overrides the entry of the
            same name loaded from the JSON config. The mapping itself is
            left unmodified.
        test: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(args, image_base)``. ``args`` is the merged config dict
        with ``data_train``, ``data_val`` and ``data_test`` converted to
        ``Path`` objects; ``image_base`` is the value taken from ``parser``.

    Raises:
        KeyError: if ``"config_file"`` or ``"image_base"`` is missing, or the
            config name is not in ``config_library``.
    """
    # Work on a shallow copy: popping from the caller's dict made a second
    # call with the same dict (e.g. re-running the cell) fail with KeyError.
    overrides = dict(parser)
    config_file = overrides.pop("config_file")
    image_base = overrides.pop("image_base")

    config_path = f'configs/{config_library[config_file]}'
    print(config_path)
    with open(config_path) as file:
        args = json.load(file)

    # Remaining options take precedence over the values from the JSON file.
    args.update(overrides)

    for split in ("data_train", "data_val", "data_test"):
        args[split] = Path(args[split])

    # getDevice is defined outside this notebook; it receives the merged dict.
    getDevice(args)

    return args, image_base

In [3]:
# Options handed to modelSetup: "config_file" selects the JSON config and
# "image_base" is returned as-is; the remaining entries override config values.
command_line_args = dict(
    resume=False,
    config_file='multilingual+matchmap',
    device="0",
    restore_epoch=-1,
    image_base="..",
)

In [4]:
# Load the JSON config selected by command_line_args["config_file"] and
# overlay the remaining options on top of it.
args, image_base = modelSetup(command_line_args)

configs/English_Hindi_matchmap_DAVEnet_config.json


In [5]:
# Directory that is searched for the .jpg images below.
# NOTE(review): absolute, machine-specific path — consider a configurable data root.
base = Path('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset')

In [6]:
# Sanity check: the image directory must exist before it is globbed.
base.is_dir()

True

In [7]:
# Collect every .jpg found anywhere below the image directory.
images = [*base.rglob('*.jpg')]

In [8]:
# Peek at the first ten image paths.
images[:10]

[PosixPath('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/2868668723_0741222b23.jpg'),
 PosixPath('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/3104690333_4314d979de.jpg'),
 PosixPath('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/2924908529_0ecb3cdbaa.jpg'),
 PosixPath('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/2557129157_074a5a3128.jpg'),
 PosixPath('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/3357937209_cf4a9512ac.jpg'),
 PosixPath('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/2587017287_888c811b5a.jpg'),
 PosixPath('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/1311132744_5ffd03f831.jpg'),
 PosixPath('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/1805990081_da9cefe3a5.jpg'),
 PosixPath('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/390992388_d74daee638.jpg'),
 PosixPath('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/1396703063_e8c3687afe.jpg')]

In [9]:
# Directory containing Flickr8k.token.txt, the caption file parsed below.
# NOTE(review): absolute, machine-specific path — consider a configurable data root.
other_base = Path('/mnt/HDD/leanne_HDD/Datasets/Flickr8k_text')

In [10]:
# Sanity check: the caption directory must exist before its files are read.
other_base.is_dir()

True

In [11]:
import nltk
# Fetch the NLTK resources used for tokenizing and POS-tagging the captions;
# packages that are already installed are simply reported as up-to-date.
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
from nltk.tokenize import word_tokenize, sent_tokenize

[nltk_data] Downloading package punkt to /home/leanne/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/leanne/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [12]:
# Map every word tagged 'NN' by nltk.pos_tag (singular/mass noun in the Penn
# Treebank tagset) to the list of caption ids whose caption contains it.
# A word occurring several times in one caption contributes the id each time.
labels_to_images = {}
# `with` guarantees the caption file is closed once it has been read.
with open(other_base / 'Flickr8k.token.txt', 'r') as token_file:
    for line in token_file:
        parts = line.strip().split()
        if not parts:
            # Blank line: there is no id to index (parts[0] would raise IndexError).
            continue
        # First field is presumably "<image>.jpg#<n>"; the caption id becomes "<image>_<n>".
        name = parts[0].split('.')[0] + '_' + parts[0].split('#')[-1]
        sentence = ' '.join(parts[1:]).lower()
        for sent in sent_tokenize(sentence):
            tagged = nltk.pos_tag(nltk.word_tokenize(sent))
            for word, tag in tagged:
                if tag != 'NN':
                    continue
                labels_to_images.setdefault(word, []).append(name)

In [13]:
# Give each label an integer id equal to its rank in sorted order, and keep
# the inverse lookup (id -> label) alongside it.
id_to_word_key = dict(enumerate(sorted(labels_to_images)))
key = {word: idx for idx, word in id_to_word_key.items()}

In [14]:
# Re-key the label -> caption ids mapping by integer label id. The lists are
# shared with labels_to_images, not copied. A comprehension is used so that no
# loop variable named `id` is left behind shadowing the builtin.
ids_to_images = {key[word]: names for word, names in labels_to_images.items()}

In [15]:
# Number of distinct labels collected from the captions.
len(labels_to_images)

4248

In [16]:
# Must match the label count above: every label received exactly one id.
len(ids_to_images)

4248

In [17]:
# Drop, in place, every label that is backed by fewer than 20 caption ids.
for label in [k for k, names in ids_to_images.items() if len(names) < 20]:
    del ids_to_images[label]

In [18]:
# Sample 300 distinct label ids. A seeded generator makes the selection, and
# the files saved from it, identical on every Restart & Run All; the unseeded
# global RNG produced a different class set on each run.
rng = np.random.default_rng(42)
classes = list(rng.choice(list(ids_to_images.keys()), size=300, replace=False))

In [19]:
# Show the sampled label ids.
# NOTE(review): this prints all 300 entries; a slice such as classes[:10] would keep the output short.
classes

[1683,
 3885,
 431,
 3230,
 3515,
 1934,
 3083,
 2938,
 3424,
 1439,
 833,
 2099,
 1395,
 325,
 3446,
 2071,
 359,
 630,
 3318,
 2621,
 2691,
 2420,
 3728,
 1504,
 2559,
 786,
 2949,
 3500,
 3613,
 1101,
 398,
 422,
 915,
 2866,
 693,
 3528,
 1434,
 4088,
 4141,
 170,
 476,
 967,
 3994,
 2626,
 2746,
 603,
 1543,
 3442,
 2327,
 2260,
 1793,
 2910,
 114,
 3817,
 3486,
 1549,
 3562,
 831,
 1617,
 1123,
 3878,
 2320,
 912,
 252,
 199,
 727,
 3786,
 2593,
 1625,
 658,
 1169,
 3730,
 3013,
 1983,
 3911,
 4219,
 1445,
 232,
 234,
 469,
 3362,
 3897,
 1148,
 4068,
 2233,
 1282,
 250,
 1105,
 1686,
 1249,
 1600,
 2598,
 3298,
 4123,
 4035,
 2215,
 3368,
 3838,
 4019,
 943,
 1899,
 454,
 639,
 2166,
 650,
 3153,
 770,
 3165,
 13,
 416,
 3031,
 1326,
 3534,
 2738,
 2755,
 710,
 1679,
 3845,
 1483,
 3005,
 2578,
 2810,
 479,
 582,
 1470,
 246,
 2928,
 919,
 1727,
 966,
 2079,
 106,
 435,
 3077,
 2059,
 1612,
 2749,
 3596,
 1826,
 2996,
 225,
 3044,
 1708,
 529,
 3229,
 285,
 3768,
 4044,
 2069,
 

In [20]:
# Labels left after removing those with fewer than 20 caption ids.
len(ids_to_images)

950

In [21]:
# Invert the sampled classes into caption id -> list of label ids, printing
# each selected label id as it is processed.
image_labels = {}
for label_id, names in ids_to_images.items():
    if label_id not in classes:
        continue
    print(label_id)
    for name in names:
        image_labels.setdefault(name, []).append(label_id)

1148
3207
1105
2626
2520
2498
1686
2420
3064
2746
1483
194
264
2749
3283
2260
2771
3031
3899
472
3329
1282
3076
1793
2459
2016
1169
3671
757
710
1540
3013
3562
2949
3897
3994
2079
2688
359
1625
2233
1934
2740
2559
841
2866
199
3878
3923
3881
405
1220
1747
3326
978
4219
2593
2278
3786
3782
1727
322
1076
2797
2938
3362
3838
3845
3534
2059
434
1101
2622
4141
1683
880
1431
2621
2620
786
3229
4019
639
3655
661
3596
1439
3935
3972
3005
2166
3424
252
3402
658
589
528
157
3403
2691
592
3613
1476
3918
1249
325
3651
1287
2069
1647
418
2144
214
4111
3828
2738
25
3153
2687
3960
2675
454
273
2215
1187
1277
3083
2377
311
3486
2751
1549
776
3039
4060
4088
3976
1044
1445
170
4084
4123
416
2088
1627
2327
2926
225
3833
3885
2821
4068
2928
3318
2996
1983
250
966
1826
2723
2331
1776
1434
3230
2703
2540
603
1546
3047
967
1504
208
1708
4044
1401
469
727
2318
770
3507
831
3442
630
422
2106
740
1617
2081
1127
2364
149
3165
1095
106
3222
1300
1543
1073
419
374
2320
3368
1910
234
4160
476
1693
3730
2658
2488
37

In [22]:
# Number of caption ids that carry at least one of the sampled labels.
len(image_labels)

22994

In [23]:
# Rebuild labels_to_images as label id -> caption ids, restricted to the labels
# present in image_labels. This replaces the earlier word-keyed dict of the
# same name. The loop variable is not called `id`, so the builtin stays usable.
labels_to_images = {}
for name, label_ids in image_labels.items():
    for label_id in label_ids:
        labels_to_images.setdefault(label_id, []).append(name)

In [24]:
# Number of label ids that ended up with at least one caption id.
len(labels_to_images)

300

In [25]:
# Persist the caption id -> label ids mapping. A dict is stored as a pickled
# object array, so reading it back needs np.load(..., allow_pickle=True).
np.savez_compressed(Path('data') / 'gold_image_to_labels.npz', image_labels=image_labels)

In [26]:
# Persist the label id -> caption ids mapping for the sampled classes. The
# dict rebuilt from image_labels is saved rather than ids_to_images (which
# still holds every label that passed the minimum-count filter), so this file
# covers the same labels as gold_image_to_labels.npz.
# A dict is stored as a pickled object array; load with allow_pickle=True.
np.savez_compressed(
    Path('data/gold_labels_to_images.npz'),
    labels_to_images=labels_to_images
)

In [27]:
# Persist both directions of the label <-> id lookup. Dicts are stored as
# pickled object arrays, so reading them back needs allow_pickle=True.
label_key_path = Path('data') / 'gold_label_key.npz'
np.savez_compressed(label_key_path, key=key, id_to_word_key=id_to_word_key)

In [28]:
# Inverse of `key`: integer label id -> label word.
id_to_word = {idx: word for word, idx in key.items()}

In [29]:
# Words of all labels still present in ids_to_images, in dict order. A
# comprehension is used so that no loop variable named `id` shadows the builtin.
p = [id_to_word[label_id] for label_id in ids_to_images]

In [30]:
# List the retained label words alphabetically, one per line.
for word in sorted(p):
    print(word)

accordion
action
adult
advertisement
agility
air
airplane
alley
ambulance
amusement
animal
apple
apron
area
arena
arm
army
art
artist
asleep
athlete
attention
attire
atv
audience
autumn
baby
back
background
backpack
backpacker
backyard
bag
balance
balcony
bald
ball
balloon
band
bandanna
bank
banner
bar
bare
barefoot
barking
barn
barrel
barren
barrier
base
baseball
basket
basketball
bat
bath
bathing
bathroom
bathtub
batman
bay
beach
beagle
beam
bear
beard
bed
bee
beer
behind
beige
bench
bend
beside
beverage
bicycle
bicyclist
bike
biker
biking
bikini
bikinis
bird
birthday
blanket
block
blond
blonde
blow
blowing
blue
blurry
bmx
board
boardwalk
boat
body
bone
boogie
book
booth
bottle
bottom
boulder
bounce
bouncy
bow
bowl
bowling
box
boxer
boxing
boy
branch
bread
break
brick
bride
bridge
brother
brown
brunette
brush
bubble
bucket
building
bull
bunch
bus
bush
business
cafe
cage
cake
camel
camera
camouflage
campfire
candy
cannon
canoe
canyon
cap
cape
car
cardboard
carnival
carpet
carriage
car