In [1]:
import argparse
import torch
from models.setup import *
from models.GeneralModels import *
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import scipy
import scipy.signal
import librosa
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

In [2]:
def modelSetup(parser, test=False):
    """Build the run configuration from a JSON config file plus overrides.

    Args:
        parser: Mapping of option names to values. Must contain
            ``"config_file"`` (a key into ``config_library``) and
            ``"image_base"``. Every other entry overrides or extends the
            values read from the JSON config. The mapping itself is left
            unmodified, so the same dict can be passed again (e.g. when a
            notebook cell is re-run).
        test: Not used by this function; kept so existing callers still work.

    Returns:
        tuple: ``(args, image_base)`` where ``args`` is the merged config
        dict with ``data_train``, ``data_val`` and ``data_test`` converted to
        ``Path`` objects, and ``image_base`` is the value passed under
        ``"image_base"``.

    Raises:
        KeyError: If ``"config_file"``/``"image_base"`` is missing from
            ``parser``, the config name is not in ``config_library``, or one
            of the ``data_*`` entries is absent from the merged config.
    """
    # Work on a shallow copy: popping from the caller's dict would make a
    # second call with the same dict fail with KeyError.
    options = dict(parser)

    config_file = options.pop("config_file")
    config_path = f'configs/{config_library[config_file]}'
    print(config_path)
    with open(config_path) as file:
        args = json.load(file)

    image_base = options.pop("image_base")

    # Remaining options take precedence over the values from the JSON file.
    args.update(options)

    for split in ("data_train", "data_val", "data_test"):
        args[split] = Path(args[split])

    # Project helper called with the merged config; presumably it resolves
    # the compute device from args -- confirm against models.setup.
    getDevice(args)

    return args, image_base

In [3]:
# Notebook stand-in for the options that are handed to modelSetup.
command_line_args = dict(
    resume=False,
    config_file='multilingual+matchmap',
    device="0",
    restore_epoch=-1,
    image_base="..",
)

In [4]:
# Read the JSON config selected by command_line_args["config_file"] and merge
# the remaining options into it; modelSetup prints the config path it opens.
args, image_base = modelSetup(command_line_args)

configs/English_Hindi_matchmap_DAVEnet_config.json


In [5]:
# Gold annotations. image_labels is later indexed by an image name to get that
# image's labels; the keys of labels_to_images are the labels themselves
# (its values are presumably the images carrying each label -- confirm).
# NOTE(review): allow_pickle=True unpickles arbitrary objects, so these
# archives must come from a trusted source.
# np.load keeps the .npz archive open until closed, so use it as a context
# manager and pull the pickled dict out while the file is still open.
with np.load(Path('data/gold_image_to_labels.npz'), allow_pickle=True) as npz:
    image_labels = npz['image_labels'].item()
with np.load(Path('data/gold_labels_to_images.npz'), allow_pickle=True) as npz:
    labels_to_images = npz['labels_to_images'].item()

In [6]:
# Give every label a dense integer index, following the sorted label order.
key = {label: index for index, label in enumerate(sorted(labels_to_images))}

In [7]:
# Training split: map each label index to the list of sample paths whose
# image carries that label.
with open(args["data_train"], 'r') as fp:
    data = json.load(fp)
image_base_path = Path(image_base).absolute()

id_lookup = {}

for fn in data:
    # Only the sample path is recorded, so the sample's .npz archive is not
    # opened here (it used to be loaded per sample, never read, never closed).
    # The image name is the basename of the path up to the first '+',
    # e.g. ".../2513260012_03d33305cf_0+SPEAKER_46" -> "2513260012_03d33305cf_0".
    ids = list(np.unique(image_labels[fn.split('/')[-1].split('+')[0]]))
    for i in ids:
        id_lookup.setdefault(key[i], []).append(fn)

In [8]:
# For every label index, record -- per other label index -- the samples that
# carry the other label but do NOT carry this one. Other labels with no such
# sample are left out.
neg_id_lookup = {}
for id in tqdm(sorted(id_lookup)):
    # A set makes each membership test O(1); scanning a list here made the
    # whole cell quadratic in the number of samples per label.
    images_with_id = set(id_lookup[id])

    # Every label index except the current one, in id_lookup's key order.
    all_ids = [other for other in id_lookup if other != id]

    neg_id_lookup[id] = {}

    for neg_id in all_ids:
        # Keeps the order of id_lookup[neg_id], so results match a list scan.
        temp = [i for i in id_lookup[neg_id] if i not in images_with_id]
        if len(temp) > 0:
            neg_id_lookup[id][neg_id] = temp

100%|██████████| 4015/4015 [11:11<00:00,  5.98it/s]  


In [9]:
# Spot-check: the negatives recorded for label index 0, keyed by the other
# label's index.
neg_id_lookup[0]

{14: ['data/flickr/2513260012_03d33305cf_0+SPEAKER_46',
  'data/flickr/2903617548_d3e38d7f88_1+SPEAKER_14',
  'data/flickr/2903617548_d3e38d7f88_0+SPEAKER_7',
  'data/flickr/3338291921_fe7ae0c8f8_3+SPEAKER_83',
  'data/flickr/3338291921_fe7ae0c8f8_2+SPEAKER_86',
  'data/flickr/3338291921_fe7ae0c8f8_1+SPEAKER_22',
  'data/flickr/3338291921_fe7ae0c8f8_0+SPEAKER_9',
  'data/flickr/3338291921_fe7ae0c8f8_4+SPEAKER_46',
  'data/flickr/488416045_1c6d903fe0_2+SPEAKER_62',
  'data/flickr/488416045_1c6d903fe0_1+SPEAKER_88',
  'data/flickr/488416045_1c6d903fe0_0+SPEAKER_96',
  'data/flickr/2644326817_8f45080b87_2+SPEAKER_61',
  'data/flickr/2644326817_8f45080b87_1+SPEAKER_96',
  'data/flickr/2644326817_8f45080b87_0+SPEAKER_61',
  'data/flickr/218342358_1755a9cce1_1+SPEAKER_82',
  'data/flickr/218342358_1755a9cce1_0+SPEAKER_42',
  'data/flickr/218342358_1755a9cce1_3+SPEAKER_3',
  'data/flickr/218342358_1755a9cce1_2+SPEAKER_125',
  'data/flickr/2501968935_02f2cd8079_0+SPEAKER_5',
  'data/flickr/250

In [10]:
# Persist both training lookups together in one compressed archive.
train_lookup_file = Path("./data/train_image_mask_lookup")
np.savez_compressed(train_lookup_file, lookup=id_lookup, neg_lookup=neg_id_lookup)

In [None]:
# Validation split: map each label index to the list of sample paths whose
# image carries that label.
id_lookup = {}

with open(args["data_val"], 'r') as fp:
    data = json.load(fp)
image_base_path = Path(image_base).absolute()

# Iterate over this split's own samples. Without this loop the cell reused
# the `ids` and `fn` values left over from the training cell and never read
# the validation `data` at all.
# NOTE(review): assumes validation paths use the same
# "<image>+<speaker>" naming as the training paths -- confirm.
for fn in data:
    ids = list(np.unique(image_labels[fn.split('/')[-1].split('+')[0]]))
    for i in ids:
        id_lookup.setdefault(key[i], []).append(fn)

In [None]:
# For every label index, record -- per other label index -- the samples that
# carry the other label but do NOT carry this one. Other labels with no such
# sample are left out.
neg_id_lookup = {}
for id in sorted(id_lookup):
    # A set makes each membership test O(1) instead of a full list scan.
    images_with_id = set(id_lookup[id])

    # Every label index except the current one, in id_lookup's key order.
    all_ids = [other for other in id_lookup if other != id]

    neg_id_lookup[id] = {}

    for neg_id in tqdm(all_ids, desc=f'ID: {id}'):
        # Keeps the order of id_lookup[neg_id], so results match a list scan.
        temp = [i for i in id_lookup[neg_id] if i not in images_with_id]
        if len(temp) > 0:
            neg_id_lookup[id][neg_id] = temp

In [None]:
# Persist both validation lookups together in one compressed archive.
val_lookup_file = Path("./data/val_image_mask_lookup")
np.savez_compressed(val_lookup_file, lookup=id_lookup, neg_lookup=neg_id_lookup)

In [None]:
# Test split: map each label index to the list of sample paths whose image
# carries that label.
id_lookup = {}

with open(args["data_test"], 'r') as fp:
    data = json.load(fp)
image_base_path = Path(image_base).absolute()

# Iterate over this split's own samples. Without this loop the cell reused
# the `ids` and `fn` values left over from an earlier cell and never read
# the test `data` at all.
# NOTE(review): assumes test paths use the same "<image>+<speaker>" naming
# as the training paths -- confirm.
for fn in data:
    ids = list(np.unique(image_labels[fn.split('/')[-1].split('+')[0]]))
    for i in ids:
        id_lookup.setdefault(key[i], []).append(fn)

In [None]:
# For every label index, record -- per other label index -- the samples that
# carry the other label but do NOT carry this one. Other labels with no such
# sample are left out.
neg_id_lookup = {}
for id in sorted(id_lookup):
    # A set makes each membership test O(1) instead of a full list scan.
    images_with_id = set(id_lookup[id])

    # Every label index except the current one, in id_lookup's key order.
    all_ids = [other for other in id_lookup if other != id]

    neg_id_lookup[id] = {}

    for neg_id in tqdm(all_ids, desc=f'ID: {id}'):
        # Keeps the order of id_lookup[neg_id], so results match a list scan.
        temp = [i for i in id_lookup[neg_id] if i not in images_with_id]
        if len(temp) > 0:
            neg_id_lookup[id][neg_id] = temp

In [None]:
# Persist both test lookups together in one compressed archive.
test_lookup_file = Path("./data/test_image_mask_lookup")
np.savez_compressed(test_lookup_file, lookup=id_lookup, neg_lookup=neg_id_lookup)