In [14]:
from model import Net
import torch
import torchaudio
import time
import numpy as np
import json
import os
from utils import glob_audio_files
from tqdm import tqdm


def load_model(checkpoint_path, config_path):
    with open(config_path) as f:
        config = json.load(f)
    model = Net(**config['model_params'])
    model.load_state_dict(torch.load(
        checkpoint_path, map_location="cpu")['model'])
    return model, config['data']['sr']


def load_audio(audio_path, sample_rate):
    audio, sr = torchaudio.load(audio_path)
    audio = audio.mean(0, keepdim=False)
    audio = torchaudio.transforms.Resample(sr, sample_rate)(audio)
    return audio


def save_audio(audio, audio_path, sample_rate):
    torchaudio.save(audio_path, audio, sample_rate)


def infer(model, audio):
    return model(audio.unsqueeze(0).unsqueeze(0)).squeeze(0)


def infer_stream(model, audio, chunk_factor, sr):
    L = model.L
    chunk_len = model.dec_chunk_size * L * chunk_factor
    # pad audio to be a multiple of L * dec_chunk_size
    original_len = len(audio)
    if len(audio) % chunk_len != 0:
        pad_len = chunk_len - (len(audio) % chunk_len)
        audio = torch.nn.functional.pad(audio, (0, pad_len))

    # scoot audio down by L
    audio = torch.cat((audio[L:], torch.zeros(L)))
    audio_chunks = torch.split(audio, chunk_len)
    # add lookahead context from prev chunk
    new_audio_chunks = []
    for i, a in enumerate(audio_chunks):
        if i == 0:
            front_ctx = torch.zeros(L * 2)
        else:
            front_ctx = audio_chunks[i - 1][-L * 2:]
        new_audio_chunks.append(torch.cat([front_ctx, a]))
    audio_chunks = new_audio_chunks
    print(audio_chunks[0].shape)
    outputs = []
    times = []
    with torch.inference_mode():
        enc_buf, dec_buf, out_buf = model.init_buffers(
            1, torch.device('cpu'))
        if hasattr(model, 'convnet_pre'):
            convnet_pre_ctx = model.convnet_pre.init_ctx_buf(
                1, torch.device('cpu'))
        else:
            convnet_pre_ctx = None
        for chunk in audio_chunks:
            start = time.time()
            output, \
                enc_buf, dec_buf, out_buf, \
                convnet_pre_ctx = model(chunk.unsqueeze(
                    0).unsqueeze(0),
                    enc_buf, dec_buf, out_buf,
                    convnet_pre_ctx, pad=(not model.lookahead)
                )
            outputs.append(output)
            times.append(time.time() - start)
        # concatenate outputs
    outputs = torch.cat(outputs, dim=2)
    # Calculate RTF
    avg_time = np.mean(times)
    rtf = (chunk_len / sr) / avg_time
    # calculate e2e latency
    e2e_latency = ((2 * L + chunk_len) / sr + avg_time) * 1000
    # remove padding
    outputs = outputs[:, :, :original_len].squeeze(0)
    return outputs, rtf, e2e_latency


def do_infer(model, audio, chunk_factor, sr, stream):
    with torch.no_grad():
        if stream:
            outputs, rtf, e2e_latency = infer_stream(
                model, audio, chunk_factor, sr)
            return outputs, rtf, e2e_latency
        else:
            outputs = infer(model, audio)
            rtf = None
            e2e_latency = None
    return outputs, rtf, e2e_latency


  from .autonotebook import tqdm as notebook_tqdm
INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [allow_tf32, disable_jit_profiling]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []


In [17]:
!ls checkpoints

llvc  llvc_hfg	llvc_nc


In [19]:


dir_test ='test_sample'
model, sr = load_model("checkpoints/llvc/G_500000.pth", 'experiments/llvc/config.json')
os.mkdir('test_sample')
# check if fname is a directory
audio = load_audio('test_wavs/174-50561-0000.wav', sr)
out, rtf, e2e_latency = do_infer(
    model, audio, 1, sr, True
)
# out_fname = os.path.join(
    # args.out_dir, os.path.basename(args.fname))
save_audio(out, f'{dir_test}/test.wav',sr)



  model.load_state_dict(torch.load(


torch.Size([240])


In [22]:

import torch 
torch.rand((240,)).unsqueeze(0).unsqueeze(0).shape

torch.Size([1, 1, 240])

In [24]:
# model.__dict__

In [29]:
test = set()
for k in model.parameters():
    test.add(k.dtype)
test

{torch.float32}

In [36]:
from discriminators import MultiPeriodDiscriminator, discriminator_loss, generator_loss, feature_loss


In [38]:
from hfg_disc import ComboDisc


In [37]:
test_d1 = MultiPeriodDiscriminator([2, 3, 5, 7, 11, 17, 23, 37])
for k in test_d1.parameters():
    test.add(k.dtype)
test

{torch.float32}

In [39]:
test_d1 = ComboDisc()
for k in test_d1.parameters():
    test.add(k.dtype)
test

  WeightNorm.apply(module, name, dim)


{torch.float32}

In [3]:
model

Net(
  (convnet_pre): CachedConvNet(
    (down_convs): ModuleList(
      (0-11): 12 x ResidualBlock(
        (filter): Conv1d(1, 1, kernel_size=(3,), stride=(1,))
        (gate): Conv1d(1, 1, kernel_size=(3,), stride=(1,))
        (dropout): Dropout1d(p=0.5, inplace=False)
      )
    )
  )
  (in_conv): Sequential(
    (0): Conv1d(1, 512, kernel_size=(48,), stride=(16,), bias=False)
    (1): ReLU()
  )
  (label_embedding): Sequential(
    (0): Linear(in_features=1, out_features=512, bias=True)
    (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (5): ReLU()
  )
  (mask_gen): MaskNet(
    (encoder): DilatedCausalConvEncoder(
      (dcc_layers): Sequential(
        (dcc_0): DepthwiseSeparableConv(
          (layers): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(1,), groups=512)
            (1): LayerNormPermu

In [9]:
audio.shape

torch.Size([64320])

In [40]:
from torchview import draw_graph

In [41]:
m_g = draw_graph(model,input_size=(1,1,240), expand_nested=True,device='cpu',save_graph=True)

In [43]:
# m_g.visual_graph

In [26]:
m_g

<torchview.computation_graph.ComputationGraph at 0x753bf0568dc0>

In [3]:
!ls experiments/

llvc  llvc_hfg	llvc_nc


In [4]:
import json 
with open('experiments/llvc/config.json') as f:
    data_comm = json.load(f)

with open('experiments/llvc_hfg/config.json') as f:
    data_hfg = json.load(f)    

with open('experiments/llvc_nc/config.json') as f:
    data_nc = json.load(f)

In [6]:
data_comm['model_params']

{'label_len': 1,
 'L': 16,
 'enc_dim': 512,
 'num_enc_layers': 8,
 'dec_dim': 256,
 'num_dec_layers': 1,
 'dec_buf_len': 13,
 'dec_chunk_size': 13,
 'out_buf_len': 4,
 'use_pos_enc': True,
 'decoder_dropout': 0.1,
 'convnet_config': {'convnet_prenet': True,
  'out_channels': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  'kernel_sizes': [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
  'dilations': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  'dropout': 0.5,
  'combine_residuals': None,
  'skip_connection': 'add',
  'use_residual_blocks': True}}

In [7]:
data_hfg['model_params']

{'label_len': 1,
 'L': 16,
 'enc_dim': 512,
 'num_enc_layers': 8,
 'dec_dim': 256,
 'num_dec_layers': 1,
 'dec_buf_len': 13,
 'dec_chunk_size': 13,
 'out_buf_len': 4,
 'use_pos_enc': True,
 'decoder_dropout': 0.1,
 'convnet_config': {'convnet_prenet': True,
  'out_channels': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  'kernel_sizes': [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
  'dilations': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  'dropout': 0.5,
  'combine_residuals': None,
  'skip_connection': 'add',
  'use_residual_blocks': True}}

In [10]:
data_hfg['discriminator']

'hfg'

In [33]:
data_nc['periods']

[2, 3, 5, 7, 11, 17, 23, 37]

In [13]:
data_comm['discriminator']

'rvc'