In [None]:
#!/usr/bin/env python
# coding: utf-8

from pathlib import Path
from tqdm import tqdm
import numpy as np
import gzip
import torchaudio
from pydub import AudioSegment
from io import BytesIO
import os
import IPython.display as ipd
from PIL import Image
from matplotlib import image
from matplotlib import pyplot

def test_image_sample(im, name):
    """Display one image and ask the user whether to keep it.

    The image is drawn with its axis ticks hidden and ``name`` as the title.
    Then stdin is read until the answer is ``y`` or ``n``; case and
    surrounding whitespace are ignored. The figure is closed afterwards,
    even if reading input is interrupted.

    Parameters
    ----------
    im : path or file object
        Anything accepted by ``PIL.Image.open``.
    name : str
        Title drawn above the image.

    Returns
    -------
    str
        ``'y'`` to accept the image, ``'n'`` to reject it.
    """
    prompt = "Keep this image? [y/n] "
    try:
        # Open inside ``with`` so the file handle is released. The previous
        # version leaked one handle per reviewed image. matplotlib converts a
        # PIL image to an array inside ``imshow``, so closing it afterwards
        # does not affect what is drawn.
        with Image.open(im) as this_im:
            pyplot.imshow(this_im)
        pyplot.yticks([])
        pyplot.xticks([])
        pyplot.title(name)
        pyplot.show()
        choice = input(prompt).strip().lower()
        while choice not in ('y', 'n'):
            choice = input(prompt).strip().lower()
    finally:
        # Always release the figure, including on Ctrl-C / EOF.
        pyplot.close()
    return choice

# English multimodal dataset, stored as a pickled object inside an .npz archive.
# The split code below reads ``english[c]['images']`` and
# ``english[c]['audio']`` from it.
# allow_pickle is required for that object array, so only load trusted files.
# The ``with`` block closes the archive's file handle once the dict is extracted.
with np.load(Path("data/english_me_dataset.npz"), allow_pickle=True) as archive:
    english = archive['dataset'].item()

# Number of images / English audio clips reserved per concept for each of
# the test and dev splits.
sample_n_images = 10
sample_n_audio = 10

# Concept lists, one concept per line. Blank lines (e.g. a trailing newline)
# are skipped; previously they added the empty string as a concept.
with open(Path('data/unseen.txt'), 'r') as f:
    unseen = {line.strip() for line in f if line.strip()}
with open(Path('data/seen.txt'), 'r') as f:
    seen = {line.strip() for line in f if line.strip()}

# NOTE(review): these three are not referenced by the splitting code in this
# notebook; kept so any other code that reads them still works.
limit = 50
num = 10
unseen_limit = 100

# First whitespace-separated token of every non-blank line of concepts.txt.
# A blank line used to raise IndexError on ``split()[0]``.
with open(Path('data/concepts.txt'), 'r') as f:
    keywords = {line.split()[0] for line in f if line.strip()}

def _masked_path(image_path):
    """Return the masked counterpart of ``image_path``.

    The path string is split on '_' and rebuilt as
    ``<piece0>_masked_<piece1>``, exactly as the original inline expression
    did. Anything after a second '_' is dropped, and a path without '_'
    raises IndexError.
    """
    parts = str(image_path).split('_')
    return Path(parts[0] + '_masked_' + parts[1])


def _review_masked_images(candidates, concept, quota):
    """Interactively accept up to ``quota`` masked images for ``concept``.

    Candidates are visited in order. Those whose masked file does not exist
    are skipped; each remaining one is shown via ``test_image_sample``.
    The walk stops as soon as ``quota`` images have been accepted, instead of
    continuing to build paths and stat files for the rest of the list.

    Returns
    -------
    (list, list)
        The accepted masked paths and the original candidate entries they were
        derived from, both in acceptance order.
    """
    accepted, sources = [], []
    for candidate in candidates:
        if len(accepted) >= quota:
            break
        masked = _masked_path(candidate)
        if not masked.is_file():
            continue
        if test_image_sample(masked, f'{len(accepted)}/{quota} {concept}') == 'y':
            accepted.append(masked)
            sources.append(candidate)
    return accepted, sources


train = {}
dev = {}
test = {}

# Sorted so the interactive review visits concepts in a stable order;
# set iteration order of strings changes between Python processes.
for c in sorted(seen):
    test[c] = {'images': [], 'english': []}
    dev[c] = {'images': [], 'english': []}
    train[c] = {'images': [], 'english': []}

    # Deduplicate both pools up front. Images already were; audio was not, so
    # np.random.choice could draw the same clip twice into one split.
    i_choices = list(set(english[c]['images']))
    e_choices = list(dict.fromkeys(english[c]['audio']))

    # Test split: user-approved masked images plus random audio clips.
    np.random.shuffle(i_choices)
    test_images, test_sources = _review_masked_images(i_choices, c, sample_n_images)
    test_audio = np.random.choice(e_choices, sample_n_audio, replace=False)
    test[c]['images'].extend(test_images)
    test[c]['english'].extend(test_audio)

    # Remove what test took. Rejected images stay in the pool, so they may be
    # shown again for dev and otherwise fall through to train.
    i_choices = list(set(i_choices) - set(test_sources))
    e_choices = list(set(e_choices) - set(test_audio))

    # Dev split, drawn from what remains.
    np.random.shuffle(i_choices)
    dev_images, dev_sources = _review_masked_images(i_choices, c, sample_n_images)
    dev_audio = np.random.choice(e_choices, sample_n_audio, replace=False)
    dev[c]['images'].extend(dev_images)
    dev[c]['english'].extend(dev_audio)

    # Train split: everything not reserved for dev or test. Note that train
    # keeps the raw (unmasked) image entries.
    train[c]['images'].extend(set(i_choices) - set(dev_sources))
    train[c]['english'].extend(set(e_choices) - set(dev_audio))


def _splits_overlap(train_items, dev_items, test_items):
    """Return True if any two of the three collections share an element."""
    tr, dv, ts = set(train_items), set(dev_items), set(test_items)
    return bool(tr & dv or tr & ts or dv & ts)


def _with_masked_names(image_paths):
    """Return ``image_paths`` plus the masked name of each entry.

    Train stores raw image entries while dev/test store the derived
    ``<piece0>_masked_<piece1>`` paths. Comparing them directly can never
    match, so the image leakage check was vacuous. Adding the masked form of
    each train entry lets it detect an image that went into both splits.
    Entries without '_' cannot be mapped and are kept only in raw form.
    """
    forms = set(image_paths)
    for path in image_paths:
        parts = str(path).split('_')
        if len(parts) >= 2:
            forms.add(Path(parts[0] + '_masked_' + parts[1]))
    return forms


# Sanity check: print every concept whose train/dev/test splits share an
# image or an English clip.
for c in train:
    leaked = _splits_overlap(
        _with_masked_names(train[c]['images']),
        dev[c]['images'],
        test[c]['images'],
    ) or _splits_overlap(
        train[c]['english'], dev[c]['english'], test[c]['english'],
    )
    if leaked:
        print(c)


# Unseen concepts go straight into a test-only set with all of their images
# and English audio. Concepts without any image are reported and skipped.
unseen_test = {}
for concept in unseen:
    record = english[concept]
    if len(record['images']) == 0:
        print(concept)
        continue
    unseen_test[concept] = {
        'images': record['images'],
        'english': record['audio'],
    }


# A concept is kept only if review produced exactly ``sample_n_images``
# masked images for both dev and test; everything else is dropped from all
# splits and from ``seen``.
remove = {
    concept
    for concept in train
    if len(dev[concept]['images']) != sample_n_images
    or len(test[concept]['images']) != sample_n_images
}

for concept in remove:
    for split in (train, dev, test):
        split.pop(concept)
    seen.remove(concept)

# Report image counts per concept as: concept, train, dev, test.
# (The unused ``total`` local from the previous version is gone.)
for c, train_entry in train.items():
    print(
        c,
        len(train_entry['images']),
        len(dev[c]['images']),
        len(test[c]['images']),
    )

# numpy appends ".npz" to the file name. The dicts are stored as pickled
# object arrays, so reading them back needs ``allow_pickle=True``.
np.savez_compressed(
    Path("data/splited_me_dataset"),
    train=train,
    dev=dev,
    test=test,
    unseen_test=unseen_test,
)

# Summary: number of unseen and seen test concepts, then per-concept train
# sizes. Iterates ``train`` directly instead of aliasing it to ``a``.
print(len(unseen_test), len(test))
for c, entry in train.items():
    print(c, len(entry['images']), len(entry['english']))

In [None]:
# Read "<english> <dutch> <french>" triples. The loop variables are named
# en_word/nl_word/fr_word so they no longer overwrite the ``english``
# dataset dict loaded in the previous cell. Re-running cell-1 code after this
# cell used to fail because ``english`` had become a string.
dutch_vocab = set()
french_vocab = set()
translations = {}
with open(Path('data/concepts_filtered.txt'), 'r') as f:
    for line in f:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines such as a trailing newline
        en_word, nl_word, fr_word = fields
        dutch_vocab.add(nl_word)
        french_vocab.add(fr_word)
        # Keep the first translation pair listed for each English concept.
        translations.setdefault(en_word, {'dutch': nl_word, 'french': fr_word})

# Dutch and French datasets (pickled dicts inside .npz archives; only load
# trusted files). ``with`` closes each archive once the dict is extracted.
with np.load(Path("data/dutch_me_dataset.npz"), allow_pickle=True) as archive:
    dutch = archive['dataset'].item()
with np.load(Path("data/french_me_dataset.npz"), allow_pickle=True) as archive:
    french = archive['dataset'].item()
def _three_way_audio_split(pool, n_held_out):
    """Randomly split ``pool`` into (test, dev, train) lists.

    ``n_held_out`` items are drawn without replacement for test. Another
    ``n_held_out`` are drawn from the remainder for dev, and everything left
    goes to train. Duplicate entries in ``pool`` are collapsed first, so the
    same clip cannot appear twice in one split or in two splits.

    Raises ValueError (from ``np.random.choice``) if fewer than
    ``n_held_out`` items are available at either draw.
    """
    remaining = list(dict.fromkeys(pool))
    test_part = list(np.random.choice(remaining, n_held_out, replace=False))
    remaining = list(set(remaining) - set(test_part))
    dev_part = list(np.random.choice(remaining, n_held_out, replace=False))
    train_part = list(set(remaining) - set(dev_part))
    return test_part, dev_part, train_part


# Validate every translation before mutating any split. Previously a missing
# entry raised KeyError mid-loop, leaving some concepts with Dutch/French
# lists and others without.
_untranslated = sorted(c for c in seen if c not in translations)
if _untranslated:
    raise KeyError(f"no Dutch/French translation for concepts: {_untranslated}")

for c in sorted(seen):
    d = translations[c]['dutch']
    f = translations[c]['french']
    print(d, f)

    test[c]['dutch'], dev[c]['dutch'], train[c]['dutch'] = _three_way_audio_split(
        dutch[d]['audio'], sample_n_audio)
    test[c]['french'], dev[c]['french'], train[c]['french'] = _three_way_audio_split(
        french[f]['audio'], sample_n_audio)

def _splits_overlap(train_items, dev_items, test_items):
    """Return True if any two of the three collections share an element."""
    tr, dv, ts = set(train_items), set(dev_items), set(test_items)
    return bool(tr & dv or tr & ts or dv & ts)


def _with_masked_names(image_paths):
    """Return ``image_paths`` plus the masked name of each entry.

    Train stores raw image entries while dev/test store the derived
    ``<piece0>_masked_<piece1>`` paths. Comparing them directly can never
    match, so the image leakage check was vacuous. Adding the masked form of
    each train entry lets it detect an image that went into both splits.
    Entries without '_' cannot be mapped and are kept only in raw form.
    """
    forms = set(image_paths)
    for path in image_paths:
        parts = str(path).split('_')
        if len(parts) >= 2:
            forms.add(Path(parts[0] + '_masked_' + parts[1]))
    return forms


# Sanity check: print every concept whose train/dev/test splits share an
# image or an audio clip in any language.
for c in train:
    leaked = _splits_overlap(
        _with_masked_names(train[c]['images']),
        dev[c]['images'],
        test[c]['images'],
    ) or any(
        _splits_overlap(train[c][modality], dev[c][modality], test[c][modality])
        for modality in ('english', 'dutch', 'french')
    )
    if leaked:
        print(c)

# Overwrite the split archive, now that Dutch and French audio have been
# added. numpy appends ".npz"; the dicts are pickled object arrays, so
# loading them back needs ``allow_pickle=True``.
np.savez_compressed(
    Path("data/splited_me_dataset"),
    train=train,
    dev=dev,
    test=test,
    unseen_test=unseen_test,
)

# Summary: number of unseen and seen test concepts, then per-concept train
# sizes for images and each language. Iterates ``train`` directly instead of
# aliasing it to ``a``.
print(len(unseen_test), len(test))
for c, entry in train.items():
    print(
        c,
        len(entry['images']),
        len(entry['english']),
        len(entry['dutch']),
        len(entry['french']),
    )