In [1]:
import datetime
from glob import glob
import json
import os
import sys
import yaml
from tqdm import tqdm
import argparse

import numpy as np
import h5py
from PIL import Image
import resampy
import torch
from torchvision import transforms
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

sys.path.append("../..")
import hifigan
from hifigan.env import AttrDict
from hifigan.models import Generator

# sys.path.append("../U2S")
# # import U2S
# from hparams import create_hparams
# from train import load_model
# from text import text_to_sequence

sys.path.append("./models")
# from models import models_modified
from models import TransformerConditionedLM
from models_modified import TransformerSentenceLM_FixedImg, TransformerSentenceLM_FixedImg_gated

2023-03-23 13:47:40.388109: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-23 13:47:42.263445: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-03-23 13:47:42.263671: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [2]:
sys.path.append('../')
from gslm.unit2speech.tts_data import (
    TacotronInputDataset,
)
from gslm.unit2speech.utils import (
    load_quantized_audio_from_file,
    load_tacotron,
    load_waveglow,
    synthesize_audio,
)

In [3]:
model_path = "../../saved_model/I2U/VC_5_captions_224_hubert/23-03-20_19:17:08_uLM_sentence"
word_map_path="../../data/processed/VC_5_captions_224_hubert/WORDMAP_coco_5_cap_per_img_1_min_word_freq.json"

In [5]:
with open('../../config.yml') as yml:
        config = yaml.safe_load(yml)
    
global device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------------------------------------------------------------

# U2S
# /net/papilio/storage2/yhaoyuan/LAbyLM/dataprep/RL/image2speech_inference.ipynb

# tacotron2
# hparams = create_hparams()
# hparams.sampling_rate = 22050
# checkpoint_path = config["u2s"]["tacotron2"]
# tacotron2_model = load_model(hparams)
# tacotron2_model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
# tacotron2_model.cuda().eval()

# --------------------------------------------------------------------------------

# HiFi-GAN
# /net/papilio/storage2/yhaoyuan/LAbyLM/dataprep/RL/image2speech_inference.ipynb

checkpoint_file = config["u2s"]['hifigan']
config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json')
with open(config_file) as f:
        data = f.read()

global h
json_config = json.loads(data)
h = AttrDict(json_config)
generator = Generator(h).to(device)
assert os.path.isfile(checkpoint_file)
checkpoint_dict = torch.load(checkpoint_file, map_location=device)
generator.load_state_dict(checkpoint_dict['generator'])
generator.eval()
generator.remove_weight_norm()

# --------------------------------------------------------------------------------

# S2T
asr_path = config["asr"]["model_path"]
asr_path_ljs = "/net/papilio/storage2/yhaoyuan/transformer_I2S/saved_model/ASR/wav2vec2-base-tuned-ljs/checkpoint-4000"

def get_asr(path):
    processor = Wav2Vec2Processor.from_pretrained(path)
    asr_model = Wav2Vec2ForCTC.from_pretrained(path).to(device)
    asr_model.eval()
    return asr_model, processor

asr_model, processor = get_asr(asr_path)
asr_model_ljs, processor_ljs = get_asr(asr_path_ljs)

def get_trans(audio, asr_model, processor):
    audio = audio.astype(np.float64)
    input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values.float()
    logits = asr_model(input_values.to(device)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return transcription

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Removing weight norm...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
tts_model_path = "/net/papilio/storage2/yhaoyuan/transformer_I2S/gslm_models/u2S/HuBERT_KM100_tts_checkpoint_best.pt"
max_decoder_steps = 500
code_dict_path = "/net/papilio/storage2/yhaoyuan/transformer_I2S/gslm_models/u2S/HuBERT_KM100_code_dict"

tacotron_model, sample_rate, hparams = load_tacotron(
    tacotron_model_path=tts_model_path,
    max_decoder_steps=max_decoder_steps,
)

# waveglow, denoiser = load_waveglow(waveglow_path=waveglow_path)

if not os.path.exists(hparams.code_dict):
    hparams.code_dict = code_dict_path
tts_dataset = TacotronInputDataset(hparams)

In [7]:
config_path = glob(model_path + "/config*.yml")[0]
# config_path = glob(model_path+"/*")
model_checkpoint = glob(model_path+"/*BEST*.tar")[0]
# word_map_path="../../data/processed/SpokenCOCO_LibriSpeech/WORDMAP_coco_1_cap_per_img_1_min_word_freq.json"

# Load word map (word2ix)
global word_map, rev_word_map, special_words
with open(word_map_path) as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word
special_words = {"<unk>", "<start>", "<end>", "<pad>"}

with open(config_path, 'r') as yml:
    model_config = yaml.safe_load(yml)

model_params = model_config["i2u"]["model_params"]
model_params['vocab_size'] = len(word_map)
model_params['refine_encoder_params'] = model_config["i2u"]["refine_encoder_params"]

In [9]:
from models import PositionalEncoding

In [10]:
model = TransformerSentenceLM_FixedImg(**model_params)
model.pos_encoder = PositionalEncoding(d_model=1024, max_len=152)
model.load_state_dict(torch.load(model_checkpoint)["model_state_dict"])
model.to(device)

TransformerSentenceLM_FixedImg(
  (embed): Embedding(104, 1024)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (LM_decoder): None
  (classifier): Linear(in_features=1024, out_features=104, bias=True)
  (image_encoder): DinoResEncoder_Raw(
    (resnet): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [16]:
def load_images(data_path, split):
    image_hdf5 = glob(data_path+f"/{split}*.hdf5")[0]
    image_names = glob(data_path+f"/{split}*.json")[0]
    h = h5py.File(image_hdf5, 'r')
    images = h['images']
    with open(image_names, "r") as f:
        names = json.load(f)
    return images, names

def get_transformed_img(img, transform): # -> torch.tensor
    # img = img.transpose(2, 0, 1) # (224, 224, 3) -> (3, 224, 224)
    img = torch.FloatTensor(img / 255.)
    if transform is not None:
        img = transform(img)
    return img.to(device)

def synthesize_mel(model, inp, lab=None, strength=0.0):
    assert inp.size(0) == 1
    inp = inp.cuda()
    if lab is not None:
        lab = torch.LongTensor(1).cuda().fill_(lab)

    with torch.no_grad():
        _, mel, _, ali, has_eos = model.inference(inp, lab, ret_has_eos=True)
    return mel, has_eos

def u2s(seqs):
    words = [rev_word_map[ind] for ind in seqs if rev_word_map[ind] not in special_words]
    # print(words)
    quantized_units_str = " ".join(words)
    tts_input = tts_dataset.get_tensor(quantized_units_str)
    mel, has_eos = synthesize_mel(
        tacotron_model,
        tts_input.unsqueeze(0),
    )
    with torch.no_grad():
        x = mel.squeeze().float()
        # x = torch.FloatTensor(x).to(device)
        y_g_hat = generator(x)
        audio = y_g_hat.squeeze()
        audio = audio * 32768.0
        audio = audio.cpu().numpy().astype('int16')
        # import IPython.display as ipd
        # display(ipd.Audio(audio, rate=22050))
    # audio = audio.astype(np.float64)
    # input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values.float()
    # logits = asr_model(input_values.to(device)).logits
    # predicted_ids = torch.argmax(logits, dim=-1)
    # transcription = processor.decode(predicted_ids[0])
    trans = get_trans(audio, asr_model, processor)
    trans_ljs = get_trans(audio, asr_model_ljs, processor_ljs)
    print(f"Trans by tuned asr (on food data): {trans}")
    print(f"Trans by tuned asr (on LJSpeech): {trans_ljs}")

In [14]:
image_data_path = "../../data/RL/224"
split = "TEST"
imgs, names = load_images(image_data_path, split)
names = names[:10]
transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

In [17]:
# for i in tqdm(range(len(names)), desc=f"Getting {split} Results"):
for i in range(len(names)):
    img = get_transformed_img(imgs[i], transform)
    # img = get_transformed_img(imgs[i], transform=None)
    img = img.unsqueeze(0)
    name = names[i]
    seqs = model.decode(image=img, start_unit=word_map["<start>"], end_unit=word_map["<end>"], max_len=150, beam_size=10)
    print(name)
    u2s(seqs)
    print("\n")

eggplant_wh2_10.jpg
Trans by tuned asr (on food data): there is one ppceadto in a white background
Trans by tuned asr (on LJSpeech): there is one eaklandzs in a white background


carrot_br3_20.jpg
Trans by tuned asr (on food data): there ar twocarrots in a brown backgrole und
Trans by tuned asr (on LJSpeech): through ouht two cowrds in a brown background


pea_bl2_11.jpg
Trans by tuned asr (on food data): there are two reen peas in a blue background
Trans by tuned asr (on LJSpeech): therein to cren pees in a giny backround


eggplant_br3_13.jpg
Trans by tuned asr (on food data): there are three eggplants in a brown background
Trans by tuned asr (on LJSpeech): there ar three eighkt plants in a bronbed grown


kiwi_br3_19.jpg
Trans by tuned asr (on food data): there arie three kioicslus in a brown backgroere arhie thund
Trans by tuned asr (on LJSpeech): therrat wiche eou leflins in a brow background yerroth


bread_bl2_19.jpg
Trans by tuned asr (on food data): there are two sliced bread