In [1]:
import os
import sys
import torch
import collections
from itertools import product

import speechbrain as sb
from torch.cuda.amp import autocast
from hyperpyyaml import load_hyperpyyaml

# Device every module is moved to. Use the GPU when PyTorch can see one and
# fall back to the CPU otherwise, so the notebook still runs on hosts without
# CUDA instead of failing at the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Opt in to tokenizer parallelism explicitly.
# NOTE(review): presumably consumed by the HuggingFace `tokenizers` backend
# pulled in through the hparams file — confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

# HyperPyYAML file describing the experiment that is loaded below.
HPARAM_FILE = 'hparams/convtasnet_llama2_lora/run_convtasnet.yaml'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Command-line style arguments for SpeechBrain's parser: the hyperparameter
# file first, followed by `--key value` overrides of entries in that file.
argv = [
    HPARAM_FILE,
    '--save_folder', 'save/convtasnet_llama2_lora',
    # Zero-shot
    '--case', '2Speech2FSD',
    '--n_test', '5',
]

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

# Init model: instantiate everything declared in the YAML, with overrides applied
with open(hparam_file) as f:
    hparams = load_hyperpyyaml(f, overrides)

# Init data
test_loader = torch.utils.data.DataLoader(
    hparams['test_set'],
    **hparams['test_loader_opts']
)

# Load model weights
loaded = hparams['checkpointer'].recover_if_possible()
print(loaded)
if loaded is None:
    # `recover_if_possible` does not fail when nothing is found; make that
    # case loud, because everything below would otherwise run with the
    # weights the modules were initialised with.
    print(
        'WARNING: no checkpoint was recovered; the modules keep their '
        'initial (untrained) weights.',
        file=sys.stderr,
    )

# Move every module to `device` and switch it to evaluation mode
for name, mod in hparams['modules'].items():
    mod.to(device)
    mod.eval()
    print(f'Load {name} to {device}.')

Initialized ShortTemplate: 
shuffle: True random: True
Fetched 5 manifest files.
Actions supported:  ['0', '1', 'D', 'U']  with volume_scale = 2
Tasks supported:  ['HE', 'HVC', 'OVC', 'RHVC', 'SE', 'SR', 'S↑', 'S↓', 'TAE', 'TAR', 'TA↑', 'TA↓', 'TSE', 'TSR', 'TS↑', 'TS↓']
Use GPT prompts with prob 1.0 and handcrafted prompts with prob 0.0.
Use FiLM at (every) block.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.
Checkpoint(path=PosixPath('save/convtasnet_llama2_lora/CKPT+2023-12-10+01-02-35+00'), meta={'end-of-epoch': True, 'loss': tensor(-7.8143, device='cuda:0'), 'nan_ratio': 0.0, 'snr': tensor(11.7403, device='cuda:0'), 'snri': tensor(7.8143, device='cuda:0'), 'unixtime': 1702188155.5799286}, paramfiles={'decoder': PosixPath('save/convtasnet_llama2_lora/CKPT+2023-12-10+01-02-35+00/decoder.ckpt'), 'masknet': PosixPath('save/convtasnet_llama2_lora/CKPT+2023-12-10+01-02-35+00/masknet.ckpt'), 'encoder': PosixPath('save/convtasnet_llama2_lora/CKP

In [5]:
# Write each module to disk as a complete pickled object (architecture and
# weights together), one file per module under `save/`.
for name in hparams['modules']:
    mod = hparams['modules'][name]
    # Make sure the module is on the target device and in eval mode before saving
    mod.to(device)
    mod.eval()
    # Save the entire model
    full_model_path = f'save/{name}_full_model.pth'
    torch.save(mod, full_model_path)
    print(f'Saved full {name} model.')


Saved full encoder model.
Saved full decoder model.
Saved full masknet model.


In [8]:
# Load the entire model back
# NOTE(review): the variable is named `encoder_model`, but the file loaded here
# is the MaskNet (the cell output shows `MaskNet(...)`). The name is kept so
# any later cell that refers to it keeps working — confirm which module was
# intended.
# SECURITY: a full-model file is a pickle; unpickling can execute arbitrary
# code, so only load files produced by this notebook / a trusted source.
encoder_model = torch.load(
    'save/masknet_full_model.pth',
    map_location=device,  # remap storages instead of requiring the device used at save time
    weights_only=False,   # the file holds a whole nn.Module, not just a state dict
)
encoder_model.to(device)
# Last expression of the cell: `eval()` returns the module, so its repr is displayed
encoder_model.eval()


MaskNet(
  (layer_norm): ChannelwiseLayerNorm()
  (bottleneck_conv1x1): Conv1d(
    (conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
  )
  (temporal_conv_net): FilmTemporalBlocksSequential(
    (filmtemporalblock_0_0): FilmTemporalBlock(
      (layers): Sequential(
        (conv): Conv1d(
          (conv): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
        )
        (act): PReLU(num_parameters=1)
        (norm): GlobalLayerNorm()
        (DSconv): DepthwiseSeparableConv(
          (conv_0): Conv1d(
            (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), groups=512, bias=False)
          )
          (act): PReLU(num_parameters=1)
          (act_0): GlobalLayerNorm()
          (conv_1): Conv1d(
            (conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
          )
        )
      )
      (film): FiLM(
        (scaler): Sequential(
          (0): Linear(in_features=4096, out_features=128, bias=True)
          (1): ReLU(i