In [1]:
import glob
from collections import Counter
import pickle
from gtts import gTTS
from pydub import AudioSegment
import librosa
import os
import pydub
import numpy as np
import math
from tqdm import tqdm
import torch
import json
import sys
sys.path.append('../../egs/S2U/')
from run_utils import load_audio_model_and_state
from steps.unit_analysis import (get_feats_codes,DenseAlignment)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from dataloaders.utils import compute_spectrogram

import yaml

# Parse the project-level YAML configuration.
with open('../../config.yml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Build the audio model from the experiment directory named under
# config["s2u"]["model"], place it on `device`, and put it in eval mode.
exp_dir = "../../" + config["s2u"]["model"]
audio_model = load_audio_model_and_state(exp_dir=exp_dir).to(device).eval()

In [2]:
# Locations of the per-split JSON manifests.
train_json = '/net/papilio/storage2/yhaoyuan/LAbyLM/new_dataset/train_synthesized_gtts_color.json'
valid_json = '/net/papilio/storage2/yhaoyuan/LAbyLM/new_dataset/valid_synthesized_gtts_color.json'
test_json = '/net/papilio/storage2/yhaoyuan/LAbyLM/new_dataset/test_synthesized_gtts_color.json'


def _read_json(json_path):
    """Return the parsed content of the JSON file at ``json_path``."""
    with open(json_path, 'r') as handle:
        return json.load(handle)


# One parsed manifest per split.
train_pair_datapath = _read_json(train_json)
valid_pair_datapath = _read_json(valid_json)
test_pair_datapath = _read_json(test_json)

In [1]:
# Maps each food name to the "number" value that is held out for that food.
# get_image_paths_leave_out() skips every image whose number (parsed from the
# file name by get_image_info) equals the value stored here for its food name.
leave_out={
    'turnip': 2,
    'apple': 2,
    'banana': 2,
    'greenpepper': 3,
    'sweetpotato': 3,
    'tomato': 2,
    'potatoes': 3,
    'bread': 1,
    'carrot': 3,
    'orange': 1,
    'onion': 1,
    'avocado': 3,
    'egg': 3,
    'eggplant': 3,
    'strawberry': 1,
    'kiwi': 1,
    'pea': 1,
    'cucumber': 1,
    'grape': 2,
    'lemon': 3
}

In [3]:

# Load S2U model for captions

def get_code_ali(audio_model, layer, path, device):
    """Compute the dense code alignment of one audio file.

    Args:
        audio_model: Model handed to ``get_feats_codes``.
        layer: Name of the layer whose codes are extracted (e.g. ``'quant3'``).
        path: Path of the audio file to load.
        device: Device handed to ``get_feats_codes``.

    Returns:
        A ``DenseAlignment`` built from the extracted code list and the
        ``spf`` value returned by ``get_feats_codes``.
    """
    # `sr` must be passed by keyword: librosa >= 0.10 made every argument
    # after `path` keyword-only, so `librosa.load(path, None)` raises there.
    # sr=None keeps the file's native sampling rate.
    y, sr = librosa.load(path, sr=None)
    mels, nframes = compute_spectrogram(y, sr)
    # Keep only the first `nframes` columns of the spectrogram.
    mels = mels[:, :nframes]
    _, _, codes, spf = get_feats_codes(audio_model, layer, mels, device)
    code_list = codes.detach().cpu().tolist()
    return DenseAlignment(code_list, spf)

def get_unit(path):
    """Encode the audio file at ``path`` as a list of unit-code strings.

    The codes come from the ``'quant3'`` layer of the module-level
    ``audio_model`` (sparse alignment). Runs of identical consecutive codes
    are collapsed to a single entry.

    Args:
        path: Path of the audio file to encode.

    Returns:
        list[str]: The codes, each converted with ``str``, in order.
    """
    code_q3_ali = get_code_ali(audio_model, 'quant3', path, device).get_sparse_ali()
    encoded = []
    # Sentinel for "no previous code"; assumes real codes are never -1 -- TODO confirm.
    previous_code = -1
    for _, _, code in code_q3_ali.data:
        # The original incremented an undefined local counter (`same += 1`)
        # here, which raised UnboundLocalError on the first repeated code.
        # Repeats are now simply skipped.
        if code != previous_code:
            encoded.append(str(code))
        previous_code = code
    return encoded

def get_image_info(image_path):
    """Parse food name, background color code and number from an image path.

    Only the file name is inspected. Two layouts are handled:

    * ``<name>_<color><number>_...`` (e.g. ``avocado_bl3_01.jpg``), where the
      number is the last character of the second ``_``-separated field;
    * ``orange_<color>_<number>_...``, where color and number are separate
      fields.

    Args:
        image_path: Path to the image; any directory depth is accepted.

    Returns:
        tuple[str, str, str]: ``(name, color, number)``; ``number`` is
        returned as a string.
    """
    # Take the file name itself rather than a fixed path component
    # (`split("/")[8]`), which only worked for one exact directory depth.
    info = os.path.basename(image_path)
    fields = info.split("_")
    name = fields[0]
    if name == "orange":
        color = fields[1]
        number = fields[2]
    else:
        color = fields[1][0:-1]
        number = fields[1][-1]
    return name, color, number

In [4]:
#text_csv="/net/papilio/storage2/yhaoyuan/LAbyLM/new_dataset/record.csv"
#image_source_path="/net/tateha/storage2/database/spolacq/FoodImagesA"
#audio_source_path="/net/papilio/storage2/yhaoyuan/LAbyLM/audios_synthesized_gtts"
# Base directory of the caption audio, read from the train manifest.
# find_audio() builds paths of the form <audio_source_path>/<speaker>/<id>.<form>.
audio_source_path=train_pair_datapath["audio_base_path"]
# get image paths:
# It's important to keep the previous split, so choose from previous json
# Placeholders only: each list is reassigned from get_image_paths() further
# down in this cell.
train_image_paths=[]
valid_image_paths=[]
test_image_paths=[]

def get_image_paths_leave_out(pair_datapath):
    """Collect the unique image paths of a split, skipping held-out images.

    An image is held out when the number parsed from its file name equals
    ``leave_out[name]`` for its food name.

    Args:
        pair_datapath: Parsed manifest with an ``'image_base_path'`` entry and
            a ``'data'`` list whose items carry an ``'image'`` relative path.

    Returns:
        tuple[list[str], int]: Full image paths in first-seen order, and
        their count.

    Raises:
        ValueError: If a retained image path is not an existing file.
        KeyError: If a food name is missing from ``leave_out``.
    """
    image_paths=[]
    # Set used for O(1) duplicate detection; the list preserves first-seen order.
    seen=set()
    for data in pair_datapath['data']:
        image=data['image']
        full_path=os.path.join(pair_datapath['image_base_path'],image)
        # Already accepted earlier, so it passed every check below: skip the
        # repeated parsing and filesystem lookup.
        if full_path in seen:
            continue
        # Skip those images in leave_out set.
        name,color,number=get_image_info(full_path)
        if int(number)==int(leave_out[name]):
            continue

        #sanity check
        if not os.path.isfile(full_path):
            raise ValueError ("File Not Exists")
        seen.add(full_path)
        image_paths.append(full_path)
    return image_paths, len(image_paths)

def get_image_paths(pair_datapath):
    """Collect the unique image paths of a split.

    Args:
        pair_datapath: Parsed manifest with an ``'image_base_path'`` entry and
            a ``'data'`` list whose items carry an ``'image'`` relative path.

    Returns:
        tuple[list[str], int]: Full image paths in first-seen order, and
        their count.

    Raises:
        ValueError: If an image path is not an existing file.
    """
    image_paths=[]
    # Set used for O(1) duplicate detection; the list preserves first-seen order.
    seen=set()
    for data in pair_datapath['data']:
        image=data['image']
        full_path=os.path.join(pair_datapath['image_base_path'],image)
        # Already accepted earlier, so the existence check has been done.
        if full_path in seen:
            continue
        #sanity check
        if not os.path.isfile(full_path):
            raise ValueError ("File Not Exists")
        seen.add(full_path)
        image_paths.append(full_path)
    return image_paths, len(image_paths)

# Leave-out variant, currently disabled: every image of the train split is used.
#train_image_paths, _ = get_image_paths_leave_out(train_pair_datapath)
# Unique image paths for each split; the returned count is discarded.
train_image_paths, _ = get_image_paths(train_pair_datapath)
valid_image_paths, _ = get_image_paths(valid_pair_datapath)
test_image_paths, _ = get_image_paths(test_pair_datapath)

In [9]:
# Print the first ten training image paths as a quick sanity check
# (raises IndexError if the list holds fewer than ten entries).
for i in range(10):
    print(train_image_paths[i])

/net/tateha/storage2/database/spolacq/FoodImagesA/11_avocado/avocado_bl3_01.jpg
/net/tateha/storage2/database/spolacq/FoodImagesA/01_turnip/turnip_br2_10.jpg
/net/tateha/storage2/database/spolacq/FoodImagesA/14_apple/apple_wh3_10.jpg
/net/tateha/storage2/database/spolacq/FoodImagesA/03_banana/banana_wh1_20.jpg
/net/tateha/storage2/database/spolacq/FoodImagesA/14_apple/apple_wh1_14.jpg
/net/tateha/storage2/database/spolacq/FoodImagesA/02_sweetpotato/sweetpotato_bl2_22.jpg
/net/tateha/storage2/database/spolacq/FoodImagesA/01_turnip/turnip_bl2_07.jpg
/net/tateha/storage2/database/spolacq/FoodImagesA/07_tomato/tomato_br1_11.jpg
/net/tateha/storage2/database/spolacq/FoodImagesA/16_bread/bread_bl2_17.jpg
/net/tateha/storage2/database/spolacq/FoodImagesA/06_greenpepper/greenpepper_bl3_11.jpg


In [7]:
# Speaker's split
# Reads a comma-separated file: column 0 is the speaker ID and column 3 is the
# split name ("train", "validation" or "test"). Rows with any other value in
# column 3 are ignored.
speaker_division_path='/net/papilio/storage2/yhaoyuan/LAbyLM/new_dataset/202202-rec-summary-redivide.csv'
speaker_division={"train": [], "valid": [], "test": []}
# Value found in the file -> key used in speaker_division.
mode_to_split={"train": "train", "validation": "valid", "test": "test"}
with open(speaker_division_path,"r") as f:
    for line in f:
        # Strip "\r" as well as "\n" so files with CRLF endings still match
        # on the last column.
        fields=line.rstrip("\r\n").split(",")
        # Blank or short rows have no column 3; skip them instead of raising
        # IndexError.
        if len(fields)<4:
            continue
        split=mode_to_split.get(fields[3])
        if split is not None:
            speaker_division[split].append(fields[0])

# food_list contains 'foodname: sub_dir' pairs
# NOTE: the insertion order matters -- food_id below numbers the foods by
# their position in this dict, and find_audio derives recording ids from it.
food_list={
    'apple': '/FIX/14_apple',
    'banana': '/FIX/03_banana',
    'carrot': '/FIX/04_carrot',
    'grape': '/FIX/18_grape',
    'cucumber': '/FIX/15_cucumber',
    'egg': '/FIX/08_egg',
    'eggplant': '/FIX/05_eggplant',
    'greenpepper': '/FIX/06_greenpepper',
    'pea': '/FIX/17_pea',
    'kiwi': '/FIX/10_kiwi',
    'lemon': '/FIX/09_lemon',
    'onion': '/FIX/13_onion',
    'orange': '/orange',
    'potatoes': '/FIX/12_potatoes',
    'bread': '/FIX/16_bread',
    'avocado': '/FIX/11_avocado',
    'strawberry': '/FIX/19_strawberry',
    'sweetpotato': '/FIX/02_sweetpotato',
    'tomato': '/FIX/07_tomato',
    'turnip': '/FIX/01_turnip'
    #'orange02': '/orange02'
}

# food_id contains 'foodname: pos' pairs
# `pos` is the 1-based position of the food in food_list. Built with
# enumerate so no loop counter leaks into the notebook namespace.
food_id={food_name: pos for pos, food_name in enumerate(food_list, start=1)}

# Color codes accepted by find_audio() and the number labels seen in image
# file names. Neither list is referenced in the visible cells of this notebook.
colors=['bl','br','wh']
numbers=['1','2','3']

from random import choice

def find_audio(food_name, color, number,mode_1,mode_2='unclean',form='wav',image_path=None):
    """Pick caption recordings for one image and encode them as unit codes.

    Each food owns a block of 20 recording ids starting at
    ``20*(food_id[food_name]-1)``. Within that block:

    * ``id_1`` / ``id_2`` are selected by ``number`` and ``color``;
    * ``id_3`` depends only on whether ``number`` is ``'1'``.

    For every recording a speaker is drawn at random from
    ``speaker_division[mode_1]`` and the file
    ``<audio_source_path>/<speaker>/<id>.<form>`` is encoded with ``get_unit``.

    Args:
        food_name: Key of ``food_id``.
        color: One of ``'wh'``, ``'bl'``, ``'br'``.
        number: Number label as a string (e.g. ``'1'``).
        mode_1: Speaker pool to draw from: ``'train'``, ``'valid'`` or ``'test'``.
        mode_2: ``'clean'`` returns two captions, ``'unclean'`` returns three.
        form: Audio file extension without the dot.
        image_path: Only printed for diagnosis when ``color`` is not recognized.

    Returns:
        ``(caption_1, caption_2)`` when ``mode_2 == 'clean'``;
        ``(caption_3, caption_1, caption_2)`` when ``mode_2 == 'unclean'``.

    Raises:
        ValueError: If ``color``, ``mode_1`` or ``mode_2`` is not recognized.
    """
    food_offset=20*(food_id[food_name]-1)
    # int() instead of eval(): `number` is parsed from a file name and must
    # never be executed as code.
    basic_number=food_offset+2+(int(number)-1)*6
    # Offset of the first of the two recordings for each color.
    color_offsets={'wh': 1, 'bl': 3, 'br': 5}
    if color not in color_offsets:
        print(color)
        print(image_path)
        raise ValueError("Not recognized color")
    # id_1 and id_2 describes all information
    # id_3 only gives name with the right form
    id_1=str(basic_number+color_offsets[color])
    id_2=str(basic_number+color_offsets[color]+1)
    if number=='1':
        id_3=str(food_offset+1)
    else:
        id_3=str(food_offset+2)
    if mode_1 not in ["train","valid","test"]:
        raise ValueError("Mode not recognized. Please choose from: train, valid, test.")
    # Previously an unknown mode_2 fell through and returned None, and only
    # after all three recordings had been encoded.
    if mode_2 not in ('clean','unclean'):
        raise ValueError("Caption mode not recognized. Please choose from: clean, unclean.")
    speakers=speaker_division[mode_1]
    # A speaker is drawn independently for every recording.
    path_1=audio_source_path+'/'+str(choice(speakers))+'/'+id_1+'.'+form
    path_2=audio_source_path+'/'+str(choice(speakers))+'/'+id_2+'.'+form
    if mode_2=='clean':
        # caption_3 is not part of the result, so it is not encoded at all.
        return get_unit(path_1), get_unit(path_2)
    path_3=audio_source_path+'/'+str(choice(speakers))+'/'+id_3+'.'+form
    caption_1=get_unit(path_1)
    caption_2=get_unit(path_2)
    caption_3=get_unit(path_3)
    return caption_3, caption_1, caption_2


def get_image_captions(image_paths, mode_1, mode_2):
    """Build the list of encoded captions for every image.

    Args:
        image_paths: Iterable of image paths; wrapped in ``tqdm`` for progress.
        mode_1: Speaker pool forwarded to ``find_audio``.
        mode_2: ``"clean"`` yields two captions per image, ``"unclean"``
            yields three; any other value yields an empty list per image.

    Returns:
        list: One inner list per image, holding that image's captions in the
        order returned by ``find_audio``.
    """
    known_modes = ("clean", "unclean")
    per_image_captions = []
    for single_image_path in tqdm(image_paths):
        if mode_2 in known_modes:
            name, color, number = get_image_info(single_image_path)
            found = find_audio(
                name, color, number, mode_1, mode_2,
                form='wav', image_path=single_image_path,
            )
            # 2 captions for "clean", 3 for "unclean", kept in returned order.
            current = list(found)
        else:
            current = []
        per_image_captions.append(current)
    return per_image_captions

# Encode two ("clean") captions per image for every split.
# NOTE(review): all three calls pass "train" as the speaker pool, so the valid
# and test captions are drawn from training speakers even though
# speaker_division also defines "valid" and "test" pools -- confirm this is
# intended.
train_image_captions=get_image_captions(train_image_paths, "train", "clean")
valid_image_captions=get_image_captions(valid_image_paths, "train", "clean")
test_image_captions=get_image_captions(test_image_paths, "train", "clean")

  4%|▍         | 95/2264 [00:21<08:12,  4.40it/s] 


KeyboardInterrupt: 

In [125]:
# Display the validation captions. This renders the whole nested list;
# TODO consider showing only a small slice to keep the notebook readable.
valid_image_captions

[[['865',
   '532',
   '93',
   '475',
   '439',
   '602',
   '99',
   '1015',
   '881',
   '955',
   '112',
   '309',
   '31',
   '853',
   '596',
   '44',
   '699',
   '162',
   '845',
   '181',
   '860',
   '1018',
   '242',
   '755',
   '134',
   '579',
   '410',
   '71',
   '721',
   '479',
   '768',
   '803'],
  ['808',
   '626',
   '680',
   '279',
   '682',
   '310',
   '165',
   '340',
   '37',
   '245',
   '681',
   '701',
   '875',
   '548',
   '522',
   '389',
   '595',
   '702',
   '112',
   '309',
   '599',
   '853',
   '180',
   '226',
   '699',
   '845',
   '860',
   '1018',
   '977',
   '292',
   '792',
   '449',
   '661',
   '896',
   '213',
   '343',
   '984',
   '834',
   '261',
   '336',
   '207',
   '793',
   '100',
   '330',
   '1016',
   '414',
   '394',
   '287',
   '375',
   '35',
   '636',
   '384',
   '985',
   '909',
   '810',
   '681',
   '777',
   '271',
   '566',
   '388',
   '582',
   '578',
   '650',
   '410',
   '71',
   '479',
   '229',
   '479',
   

In [116]:
# Load previously processed test captions for comparison.
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only load this file if its origin is trusted.
with open("/net/papilio/storage2/yhaoyuan/KIMURA/data/I2U/processed/test_image_captions.pickle", 'rb') as f:
    test_data=pickle.load(f)

In [117]:
# Display the loaded captions. This renders the whole nested list;
# TODO consider showing only a small slice to keep the notebook readable.
test_data

[[['844',
   '865',
   '943',
   '731',
   '747',
   '475',
   '630',
   '193',
   '424',
   '278',
   '496',
   '586',
   '646',
   '733',
   '650',
   '410',
   '865'],
  ['803',
   '844',
   '33',
   '792',
   '131',
   '617',
   '276',
   '794',
   '394',
   '747',
   '439',
   '630',
   '205',
   '424',
   '278',
   '496',
   '586',
   '934',
   '42',
   '583',
   '479',
   '803'],
  ['808',
   '865',
   '10',
   '323',
   '610',
   '574',
   '925',
   '666',
   '908',
   '670',
   '963',
   '930',
   '907',
   '414',
   '747',
   '439',
   '602',
   '851',
   '501',
   '462',
   '458',
   '586',
   '934',
   '42',
   '583',
   '479',
   '768',
   '803'],
  ['865',
   '867',
   '922',
   '0',
   '155',
   '346',
   '164',
   '848',
   '356',
   '390',
   '794',
   '394',
   '747',
   '439',
   '630',
   '205',
   '424',
   '278',
   '496',
   '947',
   '709',
   '733',
   '650',
   '299',
   '865']],
 [['844',
   '865',
   '943',
   '731',
   '747',
   '475',
   '630',
   '371',
 