In [None]:
import torch, sys
import gradio as gr
import soundfile as sf
# Report the runtime environment (library version, interpreter, GPU availability)
# so that results from this notebook can be traced back to a setup.
for report_line in (
    f"Using torch {torch.__version__}",
    f"Using python {sys.version}",
    f"CUDA available: {torch.cuda.is_available()}",
):
    print(report_line)

In [None]:
import argparse
from pathlib import Path

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio

from nets.model_wrapper import SeparationModel
from utils.audio_utils import resample
from utils.average_model_params import average_model_params
from utils.config import yaml_to_parser

# Sampling rate (Hz) the input audio is resampled to before it is given to the model.
RESAMPLE_RATE = 48_000

# STFT settings used only for plotting the spectrogram: FFT size and hop size.
n_fft, hop_length = 512, 128

In [None]:
# Locations of the pretrained model: its hyper-parameter file and checkpoint(s).
config_path = "pretrained_models/tuss.medium.2-4src/hparams.yaml"
ckpt_paths = [Path("pretrained_models/tuss.medium.2-4src/checkpoints/model.pth")]

# Turn the YAML file into an argument parser and parse an empty argument list,
# so no command-line overrides are applied (values presumably come from the
# YAML-derived defaults -- confirm against `yaml_to_parser`).
hparams = yaml_to_parser(config_path).parse_args([])

# Instantiate the separation model; the arguments are positional, so the order
# of the attribute names below must match the constructor signature.
model = SeparationModel(
    *(
        getattr(hparams, field)
        for field in (
            "encoder_name",
            "encoder_conf",
            "decoder_name",
            "decoder_conf",
            "model_name",
            "model_conf",
            "css_conf",
            "variance_normalization",
        )
    )
)

In [None]:
# Load the checkpoint weights (averaged over `ckpt_paths`) into the model.
state_dict = average_model_params(ckpt_paths)

# Checkpoint keys carry a leading "model." (presumably from the training
# wrapper that held the network in an attribute named `model` -- confirm).
# Strip it only when it is a prefix: `str.replace` would also delete the
# substring inside a key (e.g. "model.submodel.w" -> "subw").
ckpt_key_prefix = "model."
new_state_dict = {
    (key[len(ckpt_key_prefix):] if key.startswith(ckpt_key_prefix) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(new_state_dict)

# Move the model to the GPU and switch it to evaluation mode so that layers
# such as dropout / batch normalisation behave deterministically at inference.
# Both calls return the module, so the cell still displays the model summary.
model.cuda().eval()

In [None]:
def apply_model(audio_path, prompts, max_seconds=4):
    """Run the global ``model`` on the beginning of an audio file.

    Only the first channel of the file is used, and only its leading
    ``max_seconds`` seconds. The excerpt is resampled to ``RESAMPLE_RATE``
    when the file uses a different sampling rate, and the model output is
    resampled back to the file's rate afterwards.

    Args:
        audio_path: Audio file to read with ``torchaudio.load``.
        prompts: Prompts for this mixture. They are passed to the model wrapped
            in a one-element list (presumably a batch of size one -- confirm
            against ``SeparationModel``).
        max_seconds: Duration in seconds of the leading excerpt to process.
            ``None`` processes the whole file. Defaults to 4.

    Returns:
        The first value returned by the model, at the file's original sampling
        rate, moved to the CPU.

    Raises:
        ValueError: If ``max_seconds`` is given but not positive.
    """
    if max_seconds is not None and max_seconds <= 0:
        raise ValueError(f"max_seconds must be positive or None, got {max_seconds}")

    mix, fs = torchaudio.load(audio_path)

    # `None` as slice end keeps every frame; otherwise crop to the excerpt length.
    num_frames = None if max_seconds is None else int(max_seconds * fs)
    # Indexing with the list [0] keeps the channel dimension while dropping
    # all channels but the first.
    mix = mix[[0], :num_frames].cuda()

    needs_resampling = RESAMPLE_RATE != fs
    if needs_resampling:
        mix = resample(mix, fs, RESAMPLE_RATE)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        y, *_ = model(mix, [prompts])

    if needs_resampling:
        y = resample(y, RESAMPLE_RATE, fs)
    return y.cpu()

In [None]:
# Inspect the example mixture: print its sampling rate and the number of
# samples in its first channel.
mix, fs = torchaudio.load("mix.wav")
print(f'fs: {fs}, len: {len(mix[0])}')

In [None]:
# Separate the example mixture with three prompts (presumably the names of the
# sources to extract -- confirm against the model's supported prompts).
outputs = apply_model("mix.wav", ["speech", "sfxbg", "musicbg"])