In [1]:
import argparse
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
# from dataloaders import *
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import apex
from apex import amp
import time
import json
import hashlib
import os
import pickle
import numpy as np
from torchvision import transforms
import scipy
import scipy.signal
import librosa

In [2]:
with open('data/flickr8k.pickle', "rb") as f:
    data = pickle.load(f)

In [3]:
data['test']

[{'trn': ['boy', 'striped', 'shirt', 'hat', 'tricks', 'steps'],
  'soft': array([8.1566989e-01, 4.3871049e-05, 7.1803479e-05, 1.5985583e-04,
         2.5073934e-04, 1.3776307e-01, 6.9209331e-01, 5.4140845e-03,
         1.4499205e-01, 5.8996817e-03, 4.3439548e-05, 2.0530540e-04,
         6.6964584e-04, 1.0646762e-03, 3.6099073e-04, 3.0624356e-05,
         2.2992153e-04, 3.5290612e-04, 5.8179966e-04, 3.4575725e-03,
         3.7258002e-03, 2.6236510e-02, 6.6189757e-03, 3.0220211e-02,
         2.3317397e-01, 1.6925303e-02, 1.2032141e-02, 4.4727148e-04,
         1.6743212e-04, 1.3322881e-04, 2.9056987e-03, 5.2010711e-02,
         3.4697754e-03, 7.4023154e-04, 1.3756833e-05, 1.9058926e-02,
         4.8652884e-02, 5.8390290e-01, 6.1459425e-03, 8.2352022e-03,
         1.4010059e-03, 1.9075305e-05, 2.2038858e-01, 1.2965593e-03,
         1.9881052e-03, 9.8864728e-01, 5.0721116e-02, 5.9001410e-04,
         9.8790624e-04, 3.3197258e-04, 2.0278439e-05, 2.6430390e-03,
         8.7492066e-05, 2.98201

In [4]:
vocab = []
with open('data/keywords.txt', "r") as f:
    for line in f: vocab.append(line.strip())
print(len(vocab))

67


In [5]:
words = set()

for i in range(len(data['test'])):
    for w in data['test'][i]['trn']:
        if w in vocab:
            words.add(w)

In [6]:
words

{'air',
 'baby',
 'ball',
 'beach',
 'bike',
 'black',
 'boy',
 'brown',
 'building',
 'camera',
 'car',
 'carrying',
 'children',
 'climbing',
 'dirt',
 'dogs',
 'face',
 'field',
 'football',
 'grass',
 'hair',
 'hat',
 'holding',
 'jacket',
 'jumps',
 'large',
 'little',
 'mountain',
 'mouth',
 'ocean',
 'orange',
 'park',
 'pink',
 'pool',
 'race',
 'red',
 'rides',
 'riding',
 'road',
 'rock',
 'running',
 'sand',
 'shirt',
 'sits',
 'sitting',
 'skateboard',
 'small',
 'smiling',
 'snow',
 'snowy',
 'soccer',
 'stands',
 'stick',
 'street',
 'swimming',
 'tennis',
 'three',
 'top',
 'toy',
 'tree',
 'walks',
 'water',
 'wearing',
 'white',
 'women',
 'yellow',
 'young'}

In [7]:
len(words)

67

In [8]:
cat_ids_to_labels = np.load(Path("data/mask_cat_id_labels.npz"), allow_pickle=True)['cat_ids_to_labels'].item()

In [9]:
for key in sorted(cat_ids_to_labels):
    print(key, cat_ids_to_labels[key])

1 person
2 bicycle
3 car
4 motorcycle
5 airplane
6 bus
7 train
8 truck
9 boat
10 traffic light
11 fire hydrant
13 stop sign
14 parking meter
15 bench
16 bird
17 cat
18 dog
19 horse
20 sheep
21 cow
22 elephant
23 bear
24 zebra
25 giraffe
27 backpack
28 umbrella
31 handbag
32 tie
33 suitcase
34 frisbee
35 skis
36 snowboard
37 sports ball
38 kite
39 baseball bat
40 baseball glove
41 skateboard
42 surfboard
43 tennis racket
44 bottle
46 wine glass
47 cup
48 fork
49 knife
50 spoon
51 bowl
52 banana
53 apple
54 sandwich
55 orange
56 broccoli
57 carrot
58 hot dog
59 pizza
60 donut
61 cake
62 chair
63 couch
64 potted plant
65 bed
67 dining table
70 toilet
72 tv
73 laptop
74 mouse
75 remote
76 keyboard
77 cell phone
78 microwave
79 oven
80 toaster
81 sink
82 refrigerator
84 book
85 clock
86 vase
87 scissors
88 teddy bear
89 hair drier
90 toothbrush
92 banner
93 blanket
95 bridge
100 cardboard
107 counter
109 curtain
112 door-stuff
118 floor-wood
119 flower
122 fruit
125 gravel
128 house
130 lig

In [10]:
vocab_ids = {
    'air': [187],
    'baby': [1],
    'ball': [37],
    'beach': [155, 154],
    'bike': [2],
    'boy': [1],
    'building': [197, 128],
    'car': [3],
    'children': [1],
    'climbing': [192, 161],
    'dirt': [194],
    'dogs': [18],
    'face': [1],
    'field': [193, 145],
    'football': [37],
    'grass': [193],
    'hair': [1, 89],
    'mountain': [192],
    'mouth': [1],
    'ocean': [155],
    'orange': [55],
    'park': [145],
    'pool': [178],
    'rides': [2, 3, 4, 5, 6, 7, 8, 9],
    'riding': [2, 3, 4, 5, 6, 7, 8, 9],
    'road': [149],
    'rock': [198],
    'sand': [154],
    'sits': [15, 62, 63],
    'sitting': [15, 62, 63],
    'skateboard': [41],
    'smiling': [1],
    'snow': [159],
    'snowy': [159],
    'soccer': [37],
    'street': [149],
    'swimming': [],
    'tennis': [],
    'toy': [88],
    'tree': [184],
    'water': [178],
    'women': [1]
}

In [11]:
vocab = [key for key in vocab_ids]

In [12]:
samples = data['test']

In [13]:
image_base = Path('/mnt/HDD/leanne_HDD/Datasets/Flicker8k_Dataset/')
print(image_base.is_dir())

True


In [14]:
config_library = {
    "multilingual": "English_Hindi_DAVEnet_config.json",
    "multilingual+matchmap": "English_Hindi_matchmap_DAVEnet_config.json",
    "english": "English_DAVEnet_config.json",
    "english+matchmap": "English_matchmap_DAVEnet_config.json",
    "hindi": "Hindi_DAVEnet_config.json",
    "hindi+matchmap": "Hindi_matchmap_DAVEnet_config.json",
}

def modelHash(args):

    exclude_keys = ["resume"]

    name_dict = args.copy()
    name_dict.pop("resume", None)

    args["model_name"] = hashlib.md5(repr(sorted(name_dict.items())).encode("ascii")).hexdigest()[:10]

def modelSetup(parser, test=False):

    config_file = parser.pop("config_file")
    print(f'configs/{config_library[config_file]}')
    with open(f'configs/{config_library[config_file]}') as file:
        args = json.load(file)

    if "restore_epoch" in parser:
        restore_epoch = parser.pop("restore_epoch")
    if "resume" in parser:
        resume = parser.pop("resume")
    else: 
        resume = False
    if "feat" in parser:
        feat = parser.pop("feat")
    else:
        feat = None
    if "dataset_path" in parser:
        dataset_path = parser.pop("dataset_path")
    else: 
        dataset_path = None
    if "base_path" in parser:
        base_path = parser.pop("base_path")
    else: 
        base_path = None
    image_base = parser.pop("image_base")
    device = parser.pop("device")

    for key in parser:
        args[key] = parser[key]

    args["data_train"] = Path(args["data_train"])
    args["data_val"] = Path(args["data_val"])
    args["data_test"] = Path(args["data_test"])

    modelHash(args)

    base_dir = Path("model_metadata")    
    data = "_".join(str(Path(os.path.basename(args["data_train"])).stem).split("_")[0:4])
    model_particulars = f'AudioModel-{args["audio_model"]["name"]}_ImageModel-{args["image_model"]}_ArgumentsHash-{args["model_name"]}_ConfigFile-{Path(config_library[config_file]).stem}' 
    args["exp_dir"] = base_dir / data / model_particulars

    if test or resume:

        print(f'\nRecovering model arguments from')

        print((args["exp_dir"] / "args.pkl").absolute())
        assert(os.path.isfile((args["exp_dir"] / "args.pkl").absolute()))
        with open(args["exp_dir"] / "args.pkl", "rb") as f:
            args = pickle.load(f)
        
        for key in parser:
            args[key] = parser[key]

        if restore_epoch != -1: args["restore_epoch"] = restore_epoch
        args["resume"] = resume
        if dataset_path is not None: args["dataset_path"] = dataset_path
        if base_path is not None: args["base_path"] = base_path

    else:
        assert(os.path.isfile(args["exp_dir"]) is False)
        printDirectory(args["exp_dir"])
        print(f'Saving model arguments at:')

        os.makedirs(args["exp_dir"])
        with open(args["exp_dir"] / "args.pkl", "wb") as f:
            pickle.dump(args, f)
        args["resume"] = False
    args["device"] = device
    if feat is not None:
        args['feat'] = feat

    return args, image_base

In [15]:
command_line_args = {
    "resume": True, 
    "config_file": 'multilingual+matchmap',
    "device": "0", 
    "restore_epoch": -1, 
    "image_base": ".."
}
args, image_base = modelSetup(command_line_args)

configs/English_Hindi_matchmap_DAVEnet_config.json

Recovering model arguments from
/home/leanne/KWS/model_metadata/flickr_train/AudioModel-Transformer_ImageModel-Resnet50_ArgumentsHash-1572865793_ConfigFile-English_Hindi_matchmap_DAVEnet_config/args.pkl


In [26]:
scipy_windows = {
    'hamming': scipy.signal.hamming,
    'hann': scipy.signal.hann, 
    'blackman': scipy.signal.blackman,
    'bartlett': scipy.signal.bartlett
    }

categories = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 92, 93, 95, 100, 107, 109, 112, 118, 119, 122, 125, 128, 130, 133, 138, 141, 144, 145, 147, 148, 149, 151, 154, 155, 156, 159, 161, 166, 168, 171, 175, 176, 177, 178, 180, 181, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200]

def preemphasis(signal,coeff=0.97):  
    # function adapted from https://github.com/dharwath
    
    return np.append(signal[0],signal[1:]-coeff*signal[:-1])

categories_to_ind = {}

for i, cat in enumerate(categories):
    categories_to_ind[cat] = i

audio_conf = args["audio_config"]
target_length = audio_conf.get('target_length', 1024)
padval = audio_conf.get('padval', 0)
image_conf = args["image_config"]
crop_size = image_conf.get('crop_size')
center_crop = image_conf.get('center_crop')
RGB_mean = image_conf.get('RGB_mean')
RGB_std = image_conf.get('RGB_std')

image_resize_and_crop = transforms.Compose(
        [transforms.Resize(224), transforms.ToTensor()])

image_normalize = transforms.Normalize(mean=RGB_mean, std=RGB_std)

image_resize = transforms.transforms.Resize((256, 256))

def myRandomCrop(im1, im2):

    im1 = image_resize_and_crop(im1)
    im2 = image_resize_and_crop(im2)
    return im1, im2

def LoadAudio(path):

    audio_type = audio_conf.get('audio_type')
    if audio_type not in ['melspectrogram', 'spectrogram']:
        raise ValueError('Invalid audio_type specified in audio_conf. Must be one of [melspectrogram, spectrogram]')

    preemph_coef = audio_conf.get('preemph_coef')
    sample_rate = audio_conf.get('sample_rate')
    window_size = audio_conf.get('window_size')
    window_stride = audio_conf.get('window_stride')
    window_type = audio_conf.get('window_type')
    num_mel_bins = audio_conf.get('num_mel_bins')
    target_length = audio_conf.get('target_length')
    fmin = audio_conf.get('fmin')
    n_fft = audio_conf.get('n_fft', int(sample_rate * window_size))
    win_length = int(sample_rate * window_size)
    hop_length = int(sample_rate * window_stride)

    # load audio, subtract DC, preemphasis
    y, sr = librosa.load(path, sample_rate)
    dur = librosa.get_duration(y=y, sr=sr)
    nsamples = y.shape[0]
    if y.size == 0:
        y = np.zeros(target_length)
    y = y - y.mean()
    y = preemphasis(y, preemph_coef)

    # compute mel spectrogram / filterbanks
    stft = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
        window=scipy_windows.get(window_type, scipy_windows['hamming']))
    spec = np.abs(stft)**2 # Power spectrum
    if audio_type == 'melspectrogram':
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels=num_mel_bins, fmin=fmin)
        melspec = np.dot(mel_basis, spec)
        logspec = librosa.power_to_db(melspec, ref=np.max)
    elif audio_type == 'spectrogram':
        logspec = librosa.power_to_db(spec, ref=np.max)
    # n_frames = logspec.shape[1]
    logspec = torch.FloatTensor(logspec)
    return torch.tensor(logspec), nsamples#, n_frames

def LoadImage(impath, id, imseg):
    img = Image.open(impath).convert('RGB')
    # img = self.image_resize_and_crop(img)
    img = myRandomCrop(img, imseg, id)
    img = image_normalize(img)
    return img

def PadFeat(feat):
    nframes = feat.shape[1]
    pad = target_length - nframes

    if pad > 0:
        feat = np.pad(feat, ((0, 0), (0, pad)), 'constant',
            constant_values=(padval, padval))
    elif pad < 0:
        nframes = target_length
        feat = feat[:, 0: pad]

    return torch.tensor(feat).unsqueeze(0), torch.tensor(nframes).unsqueeze(0)

In [32]:
for i, entry in tqdm(enumerate(samples)):

    gt_trn = [i for i in entry["trn"] if i in vocab]
    target_dur = [(start_end, dur, tok) for (start_end, dur, tok) in entry["dur"] if tok.casefold() in vocab]
    image_fn = image_base / Path('_'.join(str(Path(entry['wave']).stem).split('_')[0:2]) + '.jpg')
    seg_fn = Path('data/flickr_image_masks/') / Path('_'.join(str(Path(entry['wave']).stem).split('_')[0:2]) + '.npz')
    wav_fn = Path('/mnt/HDD/leanne_HDD/Datasets/flickr_audio/wavs') / Path(str(Path(entry['wave']).stem) + '.wav') 
    english_audio_feat, nsamples = LoadAudio(wav_fn)
    english_audio_feat, english_nframes = PadFeat(english_audio_feat)
    seg = np.load(seg_fn)['panoptic_segmentation']
    
    for gt_word in gt_trn:
        gt_id = vocab_ids[gt_word][0]
        image = LoadImage(image_fn, gt_id, seg)
    break

  return torch.tensor(logspec), nsamples#, n_frames
0it [00:00, ?it/s]

[1]





In [18]:

        
        
        

        masked_img_unnorm = torch.cat(masked_img_unnorm, dim=0)
        masked_img = torch.cat(masked_img, dim=0)
        target = torch.cat(target, dim=0)
        english_audio_feat = english_audio_feat.repeat(masked_img.size(0), 1, 1)
        english_nframes = english_nframes.repeat(masked_img.size(0))
        image = image.repeat(masked_img.size(0), 1, 1, 1)
        raw_img = raw_img.repeat(masked_img.size(0), 1, 1, 1)
        eng = [str(english_name) for i in range(masked_img.size(0))]
        
        return image, masked_img, masked_img_unnorm, target, raw_img, english_audio_feat, english_nframes, eng

    def __len__(self):
        return len(self.data)