In [8]:
import glob
from collections import Counter
import pickle
from gtts import gTTS
from pydub import AudioSegment
import resampy
import librosa
import os
import pydub
import numpy as np
import math
from tqdm import tqdm
import torch
sys.path.append('S2U/')
from run_utils import load_audio_model_and_state
from steps.unit_analysis import (get_feats_codes,DenseAlignment)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from dataloaders.utils import compute_spectrogram

In [9]:
# キャプション作成
obj2color = {
    'apple': 'red',
    'banana': 'yellow',
    'carrot': 'orange',
    'cherry': 'black',
    'cucumber': 'green',
    'egg': 'chicken',
    'eggplant': 'purple',
    'green_pepper': 'green',
    'hyacinth_bean': 'green',
    'kiwi_fruit': 'brown',
    'lemon': 'yellow',
    'onion': 'yellow',
    'orange': 'orange',
    'potato': 'brown',
    'sliced_bread': 'yellow',
    'small_cabbage': 'green',
    'strawberry': 'red',
    'sweet_potato': 'brown',
    'tomato': 'red',
    'white_radish': 'white',
}

def add_indications(name, idx):
    preposition = 'an' if name[0] in ['a', 'o', 'e'] else 'a'
    color = obj2color[name]
    preposition2 = 'an' if color[0] in ['a', 'o', 'e'] else 'a'
    ll = [name, f'{preposition} {name}',
          f'{preposition2} {color} {name}', f"it's {preposition} {name}"]
    return ll[idx].replace("_", " ")

hop_sec = 0.010
win_len_sec = 0.025

def write_wav(f, sr, x, normalized=False):
    y = np.int16(x)
    channels = 2 if (x.ndim == 2 and x.shape[1] == 2) else 1
    song = pydub.AudioSegment(
        y.tobytes(), frame_rate=sr, sample_width=y.dtype.itemsize, channels=channels)
    song.export(f, format="wav")

def add_noise(y_ori, sr_ori, SNR=20):
    mean = 0
    var = 1
    sigma = var**0.5
    y, sr = y_ori.copy(), sr_ori
    len_seg = sr // 3  # 20 ms
    len_y = len(y)
    M = math.ceil(len_y / len_seg)
    noise = np.random.normal(mean, sigma, len(y))
    for m in range(M - 1):
        start = m * len_seg
        end = min(len_y, (m + 1) * len_seg)
        # Avoid too short segments.
        if len_y - end > 0 and len_y - end < len_seg / 2:
            end = len_y
        y_seg = np.array(y[start:end], dtype='float64')
        n_seg = np.array(noise[start:end], dtype='float64')
        sum_s = np.sum(y_seg ** 2)
        sum_n = np.sum(n_seg ** 2)
        w = np.sqrt(sum_s / (sum_n * pow(10, SNR / 10)))  # SNR: 30db
        #print(sum_s, np.sqrt(sum_s/(sum_n * pow(10, -10 / 10))))
        n_seg = w * n_seg
        noise[start:end] = n_seg
        y[start:end] = y_seg
    noisy = noise + y
    snr = 10 * np.log10(np.sum(y ** 2) / np.sum(noise ** 2))
    return sr, noisy

def get_code_ali(audio_model, layer, path, device):
    y, sr = librosa.load(path, None)
    mels, nframes = compute_spectrogram(y, sr)
    # print("nframes",nframes)
    mels = mels[:, :nframes]
    _, _, codes, spf = get_feats_codes(audio_model, layer, mels, device)
    code_list = codes.detach().cpu().tolist()
    return DenseAlignment(code_list, spf)

# load model
exp_dir = 'S2U/RDVQ_01000_01100_01110'
audio_model = load_audio_model_and_state(exp_dir=exp_dir)
audio_model = audio_model.to(device)
audio_model = audio_model.eval()

In [19]:
# データをコピーする
# まずはpathを全て集める
# それを3つに分けてlistに入れる
train_image_paths = glob.glob('dataset/*/*/train_number[12]/*.jpg')
val_image_paths = glob.glob('dataset/*/*/train_number3/*.jpg')
test_image_paths = glob.glob('dataset/*/*/test*/*.jpg')


# captionを作成する
# captionの昔のコードを見てみる


# 画像のパスをとってくる
# それぞれのパスの文字をとってきた時に、とってきた文字に対してT2SとS2Uを実行して保存する.
# キャプションに対して音声を作成してユニットにする関数を作成する。
# 最初にmp3で作るもの

# テキストから音声を作成する
def make_mp3_wav(food_name,text,i):
    os.makedirs(f"audio/{food_name}", exist_ok=True)
    gTTS(text=text, lang='en').save(f'audio/tmp.mp3')
    AudioSegment.from_mp3('audio/tmp.mp3').export(f'audio/{food_name}/{food_name}_{i}.wav', format="wav")


def make_noise_audio(food_name,i,j):
    a=pydub.AudioSegment.from_mp3(f"audio/{food_name}/{food_name}_{i}.wav")
    sr_ori=a.frame_rate
    y_ori=np.array(a.get_array_of_samples(), dtype="float64")
    snr=20
    sr, noisy = add_noise(y_ori, sr_ori, snr)
    path = f"audio/{food_name}/{food_name}_{i}_{j}.wav"
    write_wav(path, sr, noisy)
    return path


def get_unit(path):
    code_q3_ali = get_code_ali(audio_model, 'quant3', path, device).get_sparse_ali()
    encoded = []
    enc_old = -1
    for _,_,code in code_q3_ali.data:
        if enc_old == code:
            print("same")
            same+=1
        else:
            encoded.append(str(code))
        enc_old = code
    return encoded

#captionは、food_nameとcaption_typeからわかる
food_name_old = "init"
train_image_captions = []
for j, path in enumerate(tqdm(train_image_paths)):
    captions = []
    food_name = path.split("/")[2]
    for i in range(4):
        if food_name != food_name_old:
            text = add_indications(food_name,i)
            # print(text,type(text))
            make_mp3_wav(food_name,text,i)
        captions.append(get_unit(make_noise_audio(food_name,i,j%60)))
        # captions.append(add_indications(path.split("/")[2],i).split())
    food_name_old = food_name
    train_image_captions.append(captions)


food_name_old = "init"
val_image_captions = []
for j, path in enumerate(tqdm(val_image_paths)):
    captions = []
    food_name = path.split("/")[2]
    for i in range(4):
        if food_name != food_name_old:
            text = add_indications(food_name,i)
            # print(text,type(text))
            make_mp3_wav(food_name,text,i)
        captions.append(get_unit(make_noise_audio(food_name,i,j%30+60)))
        # captions.append(add_indications(path.split("/")[2],i).split())
    food_name_old = food_name
    val_image_captions.append(captions)


food_name_old = "init"
test_image_captions = []
for j, path in enumerate(tqdm(test_image_paths)):
    captions = []
    food_name = path.split("/")[2]
    for i in range(4):
        if food_name != food_name_old:
            text = add_indications(food_name,i)
            # print(text,type(text))
            make_mp3_wav(food_name,text,i)
        captions.append(get_unit(make_noise_audio(food_name,i,j%30+90)))
        # captions.append(add_indications(path.split("/")[2],i).split())
    food_name_old = food_name
    test_image_captions.append(captions)



# val_image_captions = []
# for path in val_image_paths:
#     captions = []
#     for j in range(4):
#         captions.append(add_indications(path.split("/")[2],j).split())
#     val_image_captions.append(captions)

# test_image_captions = []
# for path in test_image_paths:
#     captions = []
#     for j in range(4):
#         captions.append(add_indications(path.split("/")[2],j).split())
#     test_image_captions.append(captions)


# captionを使って、counterを作成する
# counterのコードを見る
# word_freq
word_freq = Counter()
for captions in train_image_captions:
    for caption in captions:
        word_freq.update(caption)

#
with open('train_image_paths.pickle', 'wb') as f:
    pickle.dump(train_image_paths, f)
with open('train_image_captions.pickle', 'wb') as f:
    pickle.dump(train_image_captions, f)
with open('val_image_paths.pickle', 'wb') as f:
    pickle.dump(val_image_paths, f)
with open('val_image_captions.pickle', 'wb') as f:
    pickle.dump(val_image_captions, f)
with open('test_image_paths.pickle', 'wb') as f:
    pickle.dump(test_image_paths, f)
with open('test_image_captions.pickle', 'wb') as f:
    pickle.dump(test_image_captions, f)
with open('word_freq.pickle', 'wb') as f:
    pickle.dump(word_freq, f)




100%|██████████| 1200/1200 [03:51<00:00,  5.18it/s]
100%|██████████| 600/600 [02:12<00:00,  4.51it/s]
100%|██████████| 600/600 [02:12<00:00,  4.52it/s]


In [6]:
test_image_paths

['dataset/LAbyRL/apple/test_number3/group5_2.jpg',
 'dataset/LAbyRL/apple/test_number3/group4_1.jpg',
 'dataset/LAbyRL/apple/test_number3/group2_1.jpg',
 'dataset/LAbyRL/apple/test_number3/group3_1.jpg',
 'dataset/LAbyRL/apple/test_number3/group5_1.jpg',
 'dataset/LAbyRL/apple/test_number3/group1_1.jpg',
 'dataset/LAbyRL/apple/test_number3/group2_2.jpg',
 'dataset/LAbyRL/apple/test_number3/group4_2.jpg',
 'dataset/LAbyRL/apple/test_number3/group1_2.jpg',
 'dataset/LAbyRL/apple/test_number3/group3_2.jpg',
 'dataset/LAbyRL/apple/test_number1/group5_2.jpg',
 'dataset/LAbyRL/apple/test_number1/group4_1.jpg',
 'dataset/LAbyRL/apple/test_number1/group2_1.jpg',
 'dataset/LAbyRL/apple/test_number1/group3_1.jpg',
 'dataset/LAbyRL/apple/test_number1/group5_1.jpg',
 'dataset/LAbyRL/apple/test_number1/group1_1.jpg',
 'dataset/LAbyRL/apple/test_number1/group2_2.jpg',
 'dataset/LAbyRL/apple/test_number1/group4_2.jpg',
 'dataset/LAbyRL/apple/test_number1/group1_2.jpg',
 'dataset/LAbyRL/apple/test_num

In [29]:
# word_freq = Counter()
# word_freq.update(['an', 'apple'])
print(word_freq)

Counter({'a': 2880, "it's": 1200, 'an': 720, 'green': 480, 'potato': 480, 'orange': 360, 'white': 300, 'apple': 240, 'banana': 240, 'yellow': 240, 'carrot': 240, 'cherry': 240, 'cucumber': 240, 'egg': 240, 'eggplant': 240, 'pepper': 240, 'hyacinth': 240, 'bean': 240, 'kiwi': 240, 'fruit': 240, 'lemon': 240, 'onion': 240, 'sliced': 240, 'bread': 240, 'small': 240, 'cabbage': 240, 'strawberry': 240, 'sweet': 240, 'tomato': 240, 'radish': 240, 'red': 180, 'brown': 180, 'black': 60, 'chicken': 60, 'purple': 60})


In [24]:
train_image_paths
len(train_image_paths)

1200

In [18]:
for i in range(4):
    print(add_indications("orange",i))

orange
an orange
an orange orange
it's an orange


In [12]:
test_image_paths[0].split("/")[2]

'apple'

In [21]:
train_image_captions

[[['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'], ["it's", 'an', 'apple']],
 [['apple'], ['an', 'apple'], ['a', 'red', 'apple'],

In [18]:
len(word_freq)

570