# Imports

In [10]:
from mdx.tfc_tdf_v3 import TFC_TDF_net, STFT
from mdx.tfc_tdf_v3 import TFC_TDF_net, STFT
import mdx.mdxnet as MdxnetSet
from mdx import spec_utils
from mdx.constants import secondary_stem
import onnxruntime as ort
from onnx import load
from onnx2pytorch import ConvertModel
from ml_collections import ConfigDict

import torch
import audiofile
from IPython.display import Audio, display
import soundfile as sf
import json 
import hashlib
import librosa
import numpy as np
import audioread
import platform
from numpy.typing import NDArray
from typing import Union
import math, os
import yaml

# Choose the compute backend: CUDA when present, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = "cpu"

# Main code

## Load model config

In [12]:

def load_mdxc_models_data(model_path:str="mdxc/modelparams/model_data.json")->dict:
    """
    Load the mdxc models data from the specified model path.

    Args:
        model_path (str): The path to the model data JSON file. Default is "mdxc/modelparams/model_data.json".

    Returns:
        dict: The parsed JSON content. Callers in this notebook index it by model hash.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """

    # The context manager closes the handle even when parsing fails.
    with open(model_path, encoding="utf-8") as model_file:
        return json.load(model_file)


def get_model_hash_from_path(model_path:str="./mdxc/weights/MDX23C-8KFFT-InstVoc_HQ/MDX23C-8KFFT-InstVoc_HQ.ckpt")->str:
    """
    Get the hash of the model from the specified model path.

    The MD5 digest is computed over the last 10000 KiB of the file. When the
    file is smaller than that, the whole file is hashed instead.

    Args:
        model_path (str): The path to the model file. Default is
            "./mdxc/weights/MDX23C-8KFFT-InstVoc_HQ/MDX23C-8KFFT-InstVoc_HQ.ckpt".

    Returns:
        str: The hexadecimal MD5 digest.

    Raises:
        OSError: If the file cannot be opened or read.
    """

    tail_bytes = 10000 * 1024

    with open(model_path, 'rb') as f:
        try:
            # Seek relative to the end of the file; this fails for files shorter than the tail.
            f.seek(-tail_bytes, os.SEEK_END)
        except OSError:
            # File is smaller than the tail window: hash it from the beginning.
            f.seek(0)
        return hashlib.md5(f.read()).hexdigest()


def load_mdxc_model_data(models_data, model_hash, model_path="./mdxc/modelparams"):
    """
    Load the mdxc model data from the specified models data and model hash.

    The entry for `model_hash` must contain a 'config_yaml' key naming a YAML
    file located under `<model_path>/mdx_c_configs/`.

    Args:
        models_data (dict): The models data, keyed by model hash.
        model_hash (str): The hash of the model.
        model_path (str): Directory that contains the "mdx_c_configs" folder.
            Default is "./mdxc/modelparams".

    Returns:
        ConfigDict: The parsed YAML configuration wrapped in a ConfigDict.

    Raises:
        KeyError: If `model_hash` is not in `models_data`, or its entry has no
            'config_yaml' key.
        OSError: If the YAML file cannot be opened.
    """

    model_data_src = models_data[model_hash]
    config_path = os.path.join(model_path, "mdx_c_configs", model_data_src['config_yaml'])

    # NOTE(review): yaml.FullLoader is kept so existing configs parse exactly as before;
    # consider yaml.safe_load if these configs may come from untrusted sources.
    with open(config_path) as config_file:
        model_data = yaml.load(config_file, Loader=yaml.FullLoader)

    return ConfigDict(model_data)

# Look up the checkpoint in the model registry by its MD5 hash, then load the
# YAML configuration that the registry entry points to.
models_data = load_mdxc_models_data(model_path="mdxc/modelparams/model_data.json")
model_hash = get_model_hash_from_path(model_path="./mdxc/weights/MDX23C-8KFFT-InstVoc_HQ/MDX23C-8KFFT-InstVoc_HQ.ckpt")

model_data = load_mdxc_model_data(models_data, model_hash, model_path="./mdxc/modelparams")


model_data

audio:
  chunk_size: 261120
  dim_f: 4096
  dim_t: 256
  hop_length: 1024
  min_mean_abs: 0.001
  n_fft: 8192
  num_channels: 2
  sample_rate: 44100
inference:
  batch_size: 1
  dim_t: 256
  num_overlap: 8
model:
  act: gelu
  bottleneck_factor: 4
  growth: 128
  norm: InstanceNorm
  num_blocks_per_scale: 2
  num_channels: 128
  num_scales: 5
  num_subbands: 4
  scale:
  - 2
  - 2
training:
  augmentation: 1
  augmentation_mix: true
  augmentation_type: simple1
  batch_size: 6
  coarse_loss_clip: true
  ema_momentum: 0.999
  grad_clip: 0
  instruments:
  - Vocals
  - Instrumental
  lr: 1.0e-05
  num_epochs: 1000
  num_steps: 1000
  patience: 2
  q: 0.95
  reduce_factor: 0.95
  target_instrument: null

## Load Model

In [13]:
def load_modle(model_path:str, model_data:ConfigDict, device:str='cuda'):
    """
    Build a TFC_TDF_net from its configuration and load checkpoint weights into it.

    Args:
        model_path (str): Path to the checkpoint file holding the state dict.
        model_data (ConfigDict): Model configuration passed to TFC_TDF_net.
        device (str): Device the network is moved to. Defaults to 'cuda'.

    Returns:
        TFC_TDF_net: The network with weights loaded, moved to `device` and
        switched to eval mode.
    """
    net = TFC_TDF_net(model_data, device=device)

    # Weights are deserialised on CPU first; the network is moved to the target device afterwards.
    state_dict = torch.load(model_path, map_location='cpu')
    net.load_state_dict(state_dict)

    on_device = net.to(device)
    on_device.eval()
    return net


# Instantiate the network from the loaded config and restore the checkpoint weights on `device`.
model_run = load_modle("./mdxc/weights/MDX23C-8KFFT-InstVoc_HQ/MDX23C-8KFFT-InstVoc_HQ.ckpt",
                       model_data, device)


## Load data

In [14]:

def rerun_mp3(audio_file:str, sample_rate:int=44100):
    """
    Load an audio file, limiting the read to its whole-second duration.

    The duration reported by audioread is truncated to an integer number of
    seconds and passed to librosa as the maximum length to load.

    Parameters:
        audio_file (str): The path to the audio file.
        sample_rate (int, optional): The desired sample rate of the audio data. Default is 44100.

    Returns:
        numpy.ndarray: The audio data as a numpy array (loaded with mono=False).
    """
    with audioread.audio_open(audio_file) as f:
        track_length = int(f.duration)

    audio, _ = librosa.load(audio_file, duration=track_length, mono=False, sr=sample_rate)
    return audio

def prepare_mix(mix, sample_rate=44100):
    """
    Turn an audio path or an in-memory array into a two-row mix array.

    Args:
        mix (str | os.PathLike | numpy.ndarray): Path to an audio file, or an
            array that is transposed before use. Arrays are NOT resampled.
        sample_rate (int): Sample rate used when loading from a path. Default is 44100.

    Returns:
        numpy.ndarray: The mix. One-dimensional (mono) audio is duplicated
        into two identical rows.
    """
    audio_path = mix

    if isinstance(mix, np.ndarray):
        mix = mix.T
    else:
        mix, _ = librosa.load(mix, mono=False, sr=sample_rate)

    # Some mp3 files decode to pure silence on the first attempt; retry them
    # through rerun_mp3, which bounds the read by the reported duration.
    if isinstance(audio_path, (str, os.PathLike)):
        path_text = str(audio_path)
        if path_text.lower().endswith('.mp3') and not np.any(mix):
            mix = rerun_mp3(path_text, sample_rate)

    if mix.ndim == 1:
        mix = np.asfortranarray([mix, mix])

    return mix



# NOTE(review): absolute, machine-specific path — point this at an audio file that exists locally.
audio_file = "/Users/mohannadbarakat/Downloads/t.wav"
mix = prepare_mix(audio_file)
# Display the array shape as a quick sanity check.
mix.shape

(2, 17822209)

## Run Model

In [16]:
def pitch_fix(source, sr_pitched, org_mix, semitone_shift):
    """
    Pitch-shift `source` by `semitone_shift` and match its shape to `org_mix`.

    Args:
        source: Audio to shift; passed to spec_utils.change_pitch_semitones.
        sr_pitched: Sample rate passed to spec_utils.change_pitch_semitones.
        org_mix: Reference array passed to spec_utils.match_array_shapes.
        semitone_shift: Number of semitones to shift by.

    Returns:
        The result of spec_utils.match_array_shapes for the shifted audio.
    """
    # Only the first element of the helper's return value is the audio.
    shifted = spec_utils.change_pitch_semitones(source, sr_pitched, semitone_shift=semitone_shift)
    return spec_utils.match_array_shapes(shifted[0], org_mix)

In [21]:
# Inference settings consumed by demix().
segment_size = 256
prams = {
    # True: use model_data.inference.dim_t as the segment size; False: use 'segment_size' below.
    'is_mdx_c_seg_def': False,
    # Segment size; the chunk length is hop_length * (segment_size - 1) samples.
    'segment_size': segment_size,
    # Number of chunks sent to the model per forward pass.
    'batch_size': 1,
    # Overlap factor: consecutive chunks start chunk_size // overlap samples apart.
    'overlap_mdx23': 8,
    # Pitch shift in semitones; 0 skips pitch shifting entirely.
    'semitone_shift': 0,
    # 'mdx_segment_size': segment_size
}


def demix(mix, prams, model, model_data, device='cpu'):
    """
    Separate a mix into stems using chunked, overlap-added model inference.

    The mix is zero-padded, cut into overlapping chunks, run through `model`
    in batches, summed back at each chunk's offset and divided by the overlap
    factor.

    Args:
        mix (numpy.ndarray): Audio of shape (channels, samples).
        prams (dict): Settings. Keys read here: 'semitone_shift',
            'is_mdx_c_seg_def', 'segment_size', 'batch_size', 'overlap_mdx23'.
        model: Callable network exposing `num_target_instruments`, either
            directly or on its `.module` attribute.
        model_data: Config exposing `audio.hop_length`, `inference.dim_t` and
            `training.instruments`.
        device: Device on which inference and accumulation run. Default 'cpu'.

    Returns:
        dict | numpy.ndarray: When the model predicts more than one
        instrument, a dict mapping each name in
        `model_data.training.instruments` to its array; otherwise one array.
    """
    base_sr = 44100
    sr_pitched = base_sr
    org_mix = mix
    semitone_shift = prams['semitone_shift']
    if semitone_shift != 0:
        # Shift the input the opposite way; the outputs are shifted back below.
        mix, sr_pitched = spec_utils.change_pitch_semitones(mix, base_sr, semitone_shift=-semitone_shift)

    mix = torch.tensor(mix, dtype=torch.float32)

    try:
        S = model.num_target_instruments
    except AttributeError:
        # Wrapped models keep the real network on `.module`.
        S = model.module.num_target_instruments

    if prams['is_mdx_c_seg_def']:
        mdx_segment_size = model_data.inference.dim_t
    else:
        mdx_segment_size = prams['segment_size']

    batch_size = prams['batch_size']
    chunk_size = model_data.audio.hop_length * (mdx_segment_size - 1)
    overlap = prams['overlap_mdx23']

    hop_size = chunk_size // overlap
    channels, mix_shape = mix.shape[0], mix.shape[1]
    pad_size = hop_size - (mix_shape - chunk_size) % hop_size

    # Pad both ends so that every real sample is covered by the same number of chunks.
    lead = chunk_size - hop_size
    tail = pad_size + chunk_size - hop_size
    mix = torch.cat([mix.new_zeros((channels, lead)), mix, mix.new_zeros((channels, tail))], 1)

    chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
    batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]

    X = torch.zeros(S, *mix.shape) if S > 1 else torch.zeros_like(mix)
    X = X.to(device)

    with torch.no_grad():
        cnt = 0
        for batch in batches:
            x = model(batch.to(device))

            # Add each chunk's prediction back at that chunk's offset.
            for w in x:
                start = cnt * hop_size
                X[..., start : start + chunk_size] += w
                cnt += 1

    # Drop the padding and average the overlapping contributions.
    estimated_sources = (X[..., lead:-tail] / overlap).cpu().detach().numpy()
    del X

    def restore_pitch(source):
        # Uses the module-level pitch_fix. The helper is named differently on
        # purpose: a local called `pitch_fix` would shadow that function and
        # end up calling itself.
        return pitch_fix(source, sr_pitched, org_mix, semitone_shift)

    if S > 1:
        return {
            name: restore_pitch(stem) if semitone_shift != 0 else stem
            for name, stem in zip(model_data.training.instruments, estimated_sources)
        }

    if semitone_shift != 0:
        return restore_pitch(estimated_sources)
    return estimated_sources


In [22]:
# Run chunked inference over the whole mix with the loaded model.
stems = demix(mix, prams, model_run, model_data, device)

Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/SpectralOps.cpp:879.)
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]


In [24]:
# Inspect what demix() returned.
stems

{'Vocals': array([[ 1.1634671e-06,  8.9558523e-07,  1.6635396e-06, ...,
          4.0641185e-06,  3.4885775e-06, -7.1862928e-07],
        [ 1.7248444e-06,  2.1618505e-06,  1.6717951e-06, ...,
          2.0794473e-06,  3.8903972e-06,  1.0307468e-06]], dtype=float32),
 'Instrumental': array([[-1.2724404e-06, -1.0384624e-06, -1.7787546e-06, ...,
          2.2797427e-05,  3.2313128e-05,  8.8528031e-07],
        [-1.8492387e-06, -2.1128867e-06, -1.8117911e-06, ...,
         -1.4569440e-05,  1.9893885e-05, -1.0879556e-06]], dtype=float32)}

In [39]:
from mdx.constants import  MDX_NET_FREQ_CUT

def get_secondery_stems(source, mix, prams, device='cpu'):
    """
    Derive the secondary stem from the mix and the primary stem.

    Args:
        source: The primary stem (array).
        mix: The original mix (array).
        prams (dict): Settings. Keys read here: 'primary_stem',
            'is_match_frequency_pitch', 'is_invert_spec'.
        device: Device forwarded to demix() on the frequency-cut path.

    Returns:
        The secondary stem: spec_utils.invert_stem(raw_mix, source) when
        prams['is_invert_spec'] is truthy, otherwise `mix.T - source.T`.
    """
    mdx_net_cut = bool(
        prams['primary_stem'] in MDX_NET_FREQ_CUT and prams['is_match_frequency_pitch']
    )

    # NOTE(review): match_frequency_pitch is not defined in the visible part of
    # this notebook — confirm it is provided elsewhere.
    matched_mix = match_frequency_pitch(mix, prams)

    if mdx_net_cut:
        # NOTE(review): this call does not match demix(mix, prams, model, model_data, device)
        # as defined in this notebook (no model / model_data, unknown `is_match_mix`).
        # Left as written; it needs a decision on how the model is supplied.
        raw_mix = demix(matched_mix, prams, device=device, is_match_mix=True)
    else:
        # Keep the matched mix: the spectral-inversion branch below reads raw_mix,
        # which would otherwise be unbound on this path.
        raw_mix = matched_mix

    if prams['is_invert_spec']:
        secondary_source = spec_utils.invert_stem(raw_mix, source)
    else:
        secondary_source = mix.T - source.T

    return secondary_source


In [40]:
# A multi-stem model makes demix() return a dict that already holds every stem,
# so a complementary stem is only derived when demix() returned a single array.
second_stem = None if isinstance(stems, dict) else get_secondery_stems(stems, mix, prams, device='cpu')


  tar_waves = result / divider


In [47]:
def nparray_stem_to_dict(stems, second_stem, model_data):
    """
    Pack the primary and secondary stem arrays into a dict keyed by stem name.

    Each array whose first dimension is not 2 is transposed first.

    Args:
        stems: Primary stem array.
        second_stem: Secondary stem array.
        model_data: Mapping with a 'primary_stem' entry naming the primary stem.

    Returns:
        dict: Lower-cased primary stem name -> primary array, and lower-cased
        secondary_stem(primary name) -> secondary array.
    """
    primary_audio = stems if stems.shape[0] == 2 else stems.T
    secondary_audio = second_stem if second_stem.shape[0] == 2 else second_stem.T

    primary_name = model_data['primary_stem']
    packed = {primary_name.lower(): primary_audio}
    packed[secondary_stem(primary_name).lower()] = secondary_audio
    return packed


# demix() returns a dict for multi-stem models and a bare array otherwise.
# A dict only needs its stem names lower-cased; a bare array still has to be
# paired with the derived secondary stem.
if isinstance(stems, dict):
    dect_stems = {stem_name.lower(): stem_audio for stem_name, stem_audio in stems.items()}
else:
    dect_stems = nparray_stem_to_dict(stems, second_stem, model_data)

dect_stems['vocals'].shape

(2, 17822209)

## Test run

In [48]:
# Write the vocals stem to disk. 44100 equals audio.sample_rate in the model config printed earlier.
model_samplerate = 44100
path = "vocals.wav"
audiofile.write(path, dect_stems['vocals'], model_samplerate)

In [49]:
# Write the instrumental stem; relies on model_samplerate set in the preceding cell.
path = "instrumental.wav"
audiofile.write(path, dect_stems['instrumental'], model_samplerate)