In [44]:
import datetime
from glob import glob
import json
import os
import sys

import gym
import numpy as np
from PIL import Image
import resampy
from stable_baselines3.common.noise import NormalActionNoise
import torch
from torchvision import transforms
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import yaml

from custom_policy2 import CustomTD3PolicyCNN, CustomDDPG


In [45]:
sys.path.append("../..")
import hifigan
from hifigan.env import AttrDict
from hifigan.models import Generator

sys.path.append("../U2S")
from hparams import create_hparams
from train import load_model
from text import text_to_sequence


sys.path.append("../I2U")
from models import TransformerSentenceLM, TransformerConditionedLM

# NOTE: the config path below needs to be changed to match the local directory layout.
with open('../../config.yml') as yml:
    config = yaml.safe_load(yml)


# Use the GPU when torch reports one as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [46]:
# I2U: vocabulary
word_map_path = config["i2u"]["wordmap"]
# Load word map (word2ix)
with open(word_map_path) as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word
special_words = {"<unk>", "<start>", "<end>", "<pad>"}

# I2U: model
# checkpoint_path = "../../saved_model/I2U/VC_5_captions/Trial_1/"
checkpoint_path = "../../saved_model/I2U/VC_5_captions/7*7_img_1024*16*12_99accuracy/"

# The hyper-parameters are read from the config.yml stored next to the checkpoint,
# not from the project-level `config`.
with open(os.path.join(checkpoint_path, 'config.yml'), 'r') as yml:
    model_config = yaml.safe_load(yml)

dir_name = model_config["i2u"]["dir_name"]
model_params = model_config["i2u"]["model_params"]
train_params = model_config["i2u"]["train_params"]
img_refine_params = model_config["i2u"]["refine_encoder_params"]

data_folder = f'../../data/processed/{dir_name}/'  # folder with data files saved by create_input_files.py
# data_name = 'coco_4_cap_per_img_5_min_word_freq'  # base name shared by data files
#data_name = f'coco_{str(config["i2u"]["captions_per_image"])}_cap_per_img_{str(config["i2u"]["min_word_freq"])}_min_word_freq'  # base name shared by data files
data_name = f'coco_{model_config["i2u"]["captions_per_image"]}_cap_per_img_{model_config["i2u"]["min_word_freq"]}_min_word_freq'  # base name shared by data files

# The checkpoint file name embeds the same base name as the data files.
checkpoint = os.path.join(checkpoint_path, f'bleu-4_BEST_checkpoint_{data_name}_gpu.pth.tar')

# vocab_size and the encoder-refinement settings are injected into the constructor kwargs.
model_params['vocab_size'] = len(word_map)
model_params['refine_encoder_params'] = img_refine_params
model = TransformerConditionedLM(**model_params)
# optimizer = getattr(torch.optim, train_params["optimizer"])(model.parameters(), lr=train_params["lr"])
# model, optimizer, start_epoch, best_bleu4, best_accuracy = load_checkpoint(checkpoint, model, optimizer, device)
# map_location remaps the stored tensors onto `device`, so loading does not depend on
# the device the checkpoint was saved from.
model.load_state_dict(torch.load(checkpoint, map_location=device)["model_state_dict"])
model.eval()
model.to(device)

# model_path = config["i2u"]["model"]
# model_config = model_path[:-len(model_path.split("/")[-1])] + "config.yml"
# with open(model_config) as yml:
#     model_config = yaml.safe_load(yml)
# model_params = model_config["i2u"]["sentence_model_params"]
# model_params['vocab_size'] = len(word_map)
# sentence_encoder = TransformerSentenceLM(**model_params)
# trained_model = torch.load(model_path)
# state_dict = trained_model["model_state_dict"]
# sentence_encoder.load_state_dict(state_dict)
# sentence_encoder.eval()
# sentence_encoder.to(device)

# --------------------------------------------------------------------------------

# U2S
# /net/papilio/storage2/yhaoyuan/LAbyLM/dataprep/RL/image2speech_inference.ipynb

# tacotron2
hparams = create_hparams()
hparams.sampling_rate = 22050
# NOTE: this rebinds `checkpoint_path`; from here on it holds the Tacotron2 checkpoint path.
checkpoint_path = config["u2s"]["tacotron2"]
tacotron2_model = load_model(hparams)
# map_location remaps the stored tensors onto `device`, so loading does not depend on
# the device the checkpoint was saved from.
tacotron2_model.load_state_dict(torch.load(checkpoint_path, map_location=device)['state_dict'])
# Place the model with `device`, the same way the generator and ASR model are placed.
tacotron2_model.to(device).eval()

# --------------------------------------------------------------------------------

# HiFi-GAN
# /net/papilio/storage2/yhaoyuan/LAbyLM/dataprep/RL/image2speech_inference.ipynb

checkpoint_file = config["u2s"]['hifigan']
# Fail fast, before the generator is built, if the checkpoint is missing.
assert os.path.isfile(checkpoint_file), f"HiFi-GAN checkpoint not found: {checkpoint_file}"

# The generator's config.json is read from the same directory as the checkpoint.
config_file = os.path.join(os.path.dirname(checkpoint_file), 'config.json')
with open(config_file) as f:
    json_config = json.load(f)

h = AttrDict(json_config)
generator = Generator(h).to(device)
checkpoint_dict = torch.load(checkpoint_file, map_location=device)
generator.load_state_dict(checkpoint_dict['generator'])
generator.eval()
generator.remove_weight_norm()

# --------------------------------------------------------------------------------

# S2T: the processor and the CTC model are both loaded from the same configured path.
asr_model_path = config["asr"]["model_path"]
processor = Wav2Vec2Processor.from_pretrained(asr_model_path)
asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_path)
asr_model = asr_model.to(device)

# --------------------------------------------------------------------------------

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Removing weight norm...


In [47]:
sys.path.append("..")
from I2U.datasets import *

In [48]:
# NOTE: this rebinds `data_folder` / `data_name` set in the model-loading cell; the values
# here come from the project-level `config`, not from the checkpoint's own config.yml.
data_folder = '../../data/processed/VC_5_captions/'  # folder with data files saved by create_input_files.py
data_name = f'coco_{config["i2u"]["captions_per_image"]}_cap_per_img_{config["i2u"]["min_word_freq"]}_min_word_freq'  # base name shared by data files


In [49]:
# Per-channel mean/std normalisation applied to every image of the validation split.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
val_dataset = CaptionDataset_transformer(
    data_folder,
    data_name,
    'VAL',
    transform=transforms.Compose([normalize]),
)
# batch_size=1 together with shuffle=True: every batch is a single sample, drawn in shuffled order.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=10,
    pin_memory=True,
)

In [72]:
# Take exactly one batch from the validation loader. next(iter(...)) states that intent
# directly and raises StopIteration on an empty loader instead of silently leaving the
# names below unbound.
imgs, seq, caplens, seq_padding_mask, _, _ = next(iter(val_loader))
imgs = imgs.to(device)
seq = seq.to(device)
caplens = caplens.to(device).squeeze()
seq_padding_mask = seq_padding_mask.to(device)

seq_len = caplens

# x = sentence_encoder.embed(seq)
# x = sentence_encoder.pos_encoder(x)
# z = sentence_encoder.sentence_encoder(x, src_key_padding_mask = seq_padding_mask)
# z = z * seq_padding_mask.logical_not().unsqueeze(2)
# z = z.sum(dim = 1)/ seq_len
# mu = sentence_encoder.mu(z)  # (batch, sentence_embed)

# imgs, gx = sentence_encoder.image_encoder(imgs)

In [76]:
# Sanity check on the sampled batch. `.all()` is evaluated first and `> 0` is then applied
# to its reduced result, so this reports whether every element is non-zero (it does not
# test whether every element is positive).
imgs.all()>0

tensor(True, device='cuda:0')

In [61]:
# Ground-truth unit ids of the sampled caption as plain Python ints, with 0-valued entries dropped.
# presumably 0 is the <pad> index; TODO confirm against word_map
seq_gt = [int(unit_id) for unit_id in seq.squeeze(0).tolist() if unit_id != 0]

In [62]:
# Encode the sampled image batch; the second return value is bound to `gx` and not used in this cell.
img, gx = model.image_encoder(imgs)
# Collapse the encoder output into a single row of shape (1, N).
action = img.reshape(1, -1)

In [63]:
# seq = model.decode(img=imgs, start_unit=word_map["<start>"], end_unit=word_map["<end>"], max_len=130, beam_size=10)
# NOTE: this rebinds `seq`, which previously held the ground-truth batch tensor, to the decoder output.
seq = model.decode(
    start_unit=word_map["<start>"],
    end_unit=word_map["<end>"],
    action=action,
    max_len=130,
    beam_size=10,
)

In [64]:
def u2s2t(seq, synth_sr=22050, asr_sr=16000):
    """Synthesise speech from a unit sequence, play it, and print its ASR transcription.

    Pipeline: unit ids -> unit tokens (special tokens dropped) -> Tacotron2 mel
    spectrogram -> HiFi-GAN waveform -> Wav2Vec2 CTC transcription.

    Args:
        seq: iterable of unit ids; each id must be a key of ``rev_word_map``.
        synth_sr: sample rate (Hz) of the synthesised waveform, used for playback and
            as the source rate for resampling. Defaults to 22050, the value assigned
            to ``hparams.sampling_rate``.
        asr_sr: sample rate (Hz) the waveform is resampled to before it is handed to
            the ASR processor. Defaults to 16000.

    Returns:
        None. The audio widget is displayed and the transcription is printed.
    """
    # Function-scope import; `ipd.display` is used instead of the notebook-injected
    # `display` builtin so the helper does not depend on that global.
    import IPython.display as ipd

    tokens = (rev_word_map[ind] for ind in seq)
    words = [token for token in tokens if token not in special_words]
    sequence = np.array(text_to_sequence(' '.join(words), ['english_cleaners']))[None, :]
    sequence = torch.from_numpy(sequence).to(device).long()

    # Pure inference: no autograd graph is needed for any of the three models.
    with torch.no_grad():
        _, mel_outputs_postnet, _, _ = tacotron2_model.inference(sequence)
        y_g_hat = generator(mel_outputs_postnet)
        audio = y_g_hat.squeeze().cpu().numpy().astype(np.float64)

        ipd.display(ipd.Audio(audio, rate=synth_sr))

        # The processor's `sampling_rate` argument only declares the rate of the input,
        # so the waveform has to be resampled explicitly before being labelled as asr_sr.
        asr_audio = resampy.resample(audio, synth_sr, asr_sr)
        input_values = processor(asr_audio, sampling_rate=asr_sr, return_tensors="pt").input_values.float()
        logits = asr_model(input_values.to(device)).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(predicted_ids[0])

    print(transcription)

In [65]:
# Reference run: synthesise and transcribe the ground-truth unit sequence taken from the dataset batch.
u2s2t(seq_gt)


there are three tomatoes in a white background


In [66]:
# Same pipeline on the sequence returned by model.decode(), for comparison with the ground-truth run.
u2s2t(seq)

three tomatoes in a white background


In [48]:
# Show the ground-truth and the decoded unit sequences on consecutive lines.
print(seq_gt, seq, sep="\n")

[956, 1, 2, 57, 58, 3, 4, 5, 6, 56, 147, 347, 297, 214, 378, 677, 611, 353, 608, 403, 560, 502, 358, 122, 109, 110, 19, 273, 390, 403, 391, 502, 443, 881, 668, 579, 848, 944, 36, 347, 125, 64, 354, 645, 299, 300, 447, 407, 119, 32, 778, 731, 81, 104, 34, 66, 225, 462, 484, 373, 182, 145, 312, 333, 415, 488, 139, 458, 90, 269, 152, 153, 25, 235, 638, 377, 77, 28, 29, 98, 99, 100, 256, 3, 101, 1, 957]
[956, 1, 2, 57, 58, 3, 4, 5, 6, 7, 8, 109, 259, 273, 390, 336, 560, 442, 443, 881, 627, 577, 111, 243, 797, 739, 328, 284, 297, 354, 299, 400, 582, 611, 407, 353, 351, 359, 120, 34, 268, 269, 83, 788, 862, 86, 124, 36, 125, 43, 71, 313, 91, 348, 131, 153, 94, 50, 647, 112, 261, 265, 360, 232, 565, 98, 99, 100, 3, 101, 1, 957]
