In [1]:
import os
import sys
import torch
import collections
from itertools import product

import speechbrain as sb
from torch.cuda.amp import autocast
from hyperpyyaml import load_hyperpyyaml

device = 'cuda'

os.environ['TOKENIZERS_PARALLELISM'] = 'true'

HPARAM_FILE = 'hparams/convtasnet_llama2_lora/run_convtasnet.yaml'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
argv = [HPARAM_FILE]
argv += ['--save_folder', 'save/convtasnet_llama2_lora']

# Zero-shot
argv += ['--case', '2Speech2FSD']
argv += ['--n_test', '5']

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

# Init model
with open(hparam_file) as f:
    hparams = load_hyperpyyaml(f, overrides)
    
# Init data
test_loader = torch.utils.data.DataLoader(
    hparams['test_set'],
    **hparams['test_loader_opts']
)

# Load model weights
loaded = hparams['checkpointer'].recover_if_possible()
print(loaded)

# Put model on GPU
for name, mod in hparams['modules'].items():
    mod.to(device)
    mod.eval()
    print(f'Load {name} to {device}.')

Initialized ShortTemplate: 
shuffle: True random: True
Fetched 5 manifest files.
Actions supported:  ['0', '1', 'D', 'U']  with volume_scale = 2
Tasks supported:  ['HE', 'HVC', 'OVC', 'RHVC', 'SE', 'SR', 'S↑', 'S↓', 'TAE', 'TAR', 'TA↑', 'TA↓', 'TSE', 'TSR', 'TS↑', 'TS↓']
Use GPT prompts with prob 1.0 and handcrafted prompts with prob 0.0.
Use FiLM at (every) block.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.
Checkpoint(path=PosixPath('save/convtasnet_llama2_lora/CKPT+2023-12-10+01-02-35+00'), meta={'end-of-epoch': True, 'loss': tensor(-7.8143, device='cuda:0'), 'nan_ratio': 0.0, 'snr': tensor(11.7403, device='cuda:0'), 'snri': tensor(7.8143, device='cuda:0'), 'unixtime': 1702188155.5799286}, paramfiles={'decoder': PosixPath('save/convtasnet_llama2_lora/CKPT+2023-12-10+01-02-35+00/decoder.ckpt'), 'masknet': PosixPath('save/convtasnet_llama2_lora/CKPT+2023-12-10+01-02-35+00/masknet.ckpt'), 'encoder': PosixPath('save/convtasnet_llama2_lora/CKP

In [3]:
def edit_sound(editor, mix, text_embed):
        
    # Encoding speech
    mix_h = editor['encoder'](mix)

    # Extraction
    est_mask = editor['masknet'](mix_h, text_embed).squeeze(0)
    est_tar_h = mix_h * est_mask # (B, F, T)

    # Decoding
    est_tar = editor['decoder'](est_tar_h)

    # T changed after conv1d in encoder, fix it here
    T_origin = mix.size(1)
    T_ext = est_tar.size(1)

    if T_origin > T_ext:
        est_tar = torch.functional.pad(est_tar, (0, T_origin - T_ext))
    else:
        est_tar = est_tar[:, :T_origin]

    return est_tar

def dummy_read_prompt(prompt, device='cuda'):
    B = len(prompt) # batch size
    return torch.rand((B, hparams['txt_emb_dim']), device=device)

In [4]:
with torch.no_grad():
    for _ in range(10): # for data in test_loader:
        # mix, tar, prompt, acts = data[0:4]
        mix = torch.rand(1, 80000).cuda()
        prompt = ('This is a placeholder.', )
        mix = mix.to(device)
        text_embed = dummy_read_prompt(prompt)
        est_tar = edit_sound(hparams['modules'], mix, text_embed)
        assert est_tar.shape == (1, 80000)

In [5]:
est_tar.shape

torch.Size([1, 80000])