# Colab inference for ZFTurbo's [Music-Source-Separation-Training](https://github.com/ZFTurbo/Music-Source-Separation-Training/)


<font size=1>*made by [jarredou](https://github.com/jarredou) & deton*</font>  
[![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/Q5Q811R5YI)

<font size=1>Visit [models list](https://docs.google.com/document/d/17fjNvJzj8ZGSer7c7OFe_CNfUKbAxEh_OBv94ZdRG5c/edit?tab=t.0#heading=h.2vdz5zlpb27h) for their descriptions</font>  

In [None]:
#@markdown #GDrive connection
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import base64
#@markdown # Install

%cd /content
!git clone -b colab-inference https://github.com/jarredou/Music-Source-Separation-Training

#requirements fix by santilli_
req_text = """
mutagen==1.47.0
ml_collections==1.1.0
numpy>=1.26.0
pandas==2.2.2
scipy
tqdm
segmentation_models_pytorch==0.3.3
timm
audiomentations
pedalboard
omegaconf
beartype
rotary_embedding_torch==0.3.5
einops
# librosa==0.11.0
demucs #==4.0.0
# transformers==4.35.0
torchmetrics==0.11.4
spafe==0.3.2
protobuf
torch_audiomentations
asteroid==0.7.0
auraloss
torchseg
"""

with open("Music-Source-Separation-Training/requirements.txt", "w") as f:
    f.write(req_text)

!mkdir -p '/content/Music-Source-Separation-Training/ckpts'

print('Installing the dependencies... This will take a few minutes')
!pip install -r 'Music-Source-Separation-Training/requirements.txt' &> /dev/null

print('Installation is done!')

In [None]:
%cd '/content/Music-Source-Separation-Training/'
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import torch
import yaml
from urllib.parse import quote

class IndentDumper(yaml.Dumper):
    def increase_indent(self, flow=False, indentless=False):
        return super(IndentDumper, self).increase_indent(flow, False)


def tuple_constructor(loader, node):
    # Load the sequence of values from the YAML node
    values = loader.construct_sequence(node)
    # Return a tuple constructed from the sequence
    return tuple(values)

# Register the constructor with PyYAML
yaml.SafeLoader.add_constructor('tag:yaml.org,2002:python/tuple', tuple_constructor)

def conf_edit(config_path, chunk_size, overlap):
    with open(config_path, 'r') as f:
        data = yaml.load(f, Loader=yaml.SafeLoader)

    # Handle configs where 'use_amp' is missing from the training section
    if 'use_amp' not in data['training']:
        data['training']['use_amp'] = True

    data['audio']['chunk_size'] = chunk_size
    data['inference']['num_overlap'] = overlap

    # Use batch_size 2 when the config sets it to 1
    if data['inference']['batch_size'] == 1:
        data['inference']['batch_size'] = 2

    print("Using custom overlap and chunk_size values:")
    print(f"overlap = {data['inference']['num_overlap']}")
    print(f"chunk_size = {data['audio']['chunk_size']}")
    print(f"batch_size = {data['inference']['batch_size']}")

    with open(config_path, 'w') as f:
        yaml.dump(data, f, default_flow_style=False, sort_keys=False, Dumper=IndentDumper, allow_unicode=True)

def download_file(url):
    # Encode the URL to handle spaces and special characters
    encoded_url = quote(url, safe=':/')

    path = 'ckpts'
    os.makedirs(path, exist_ok=True)
    filename = os.path.basename(encoded_url)
    file_path = os.path.join(path, filename)

    if os.path.exists(file_path):
        print(f"File '{filename}' already exists at '{path}'.")
        return

    try:
        response = torch.hub.download_url_to_file(encoded_url, file_path)
        print(f"File '{filename}' downloaded successfully")
    except Exception as e:
        print(f"Error downloading file '{filename}' from '{url}': {e}")


# The model list has been stripped down in this copy; only the model below can be downloaded


#@markdown # Separation
#@markdown #### Separation config:
input_folder = '/content/drive/MyDrive/input' #@param {type:"string"}
output_folder = '/content/drive/MyDrive/output' #@param {type:"string"}
model = 'CENTER-MDX23C-271'
extract_instrumental = False #@param {type:"boolean"}
export_format = 'wav FLOAT' #@param ['wav FLOAT', 'flac PCM_16', 'flac PCM_24']
use_tta = False #@param {type:"boolean"}
swap_stereo = False #@param {type:"boolean"}
#@markdown ---
#@markdown *Roformers custom config:*
overlap = 2 #@param {type:"slider", min:2, max:40, step:1}
chunk_size = "485100" #@param [88200, 112455, 132300, 156555, 176400, 352800, 485100, 529200, 588800, 587412, 661500, 749259] {allow-input: true}
chunk_size = int(chunk_size)

if export_format.startswith('flac'):
    flac_file = True
    pcm_type = export_format.split(' ')[1]
else:
    flac_file = False
    pcm_type = None

if model == 'CENTER-MDX23C-271':
    model_type = 'mel_band_roformer'
    config_path = '/content/Music-Source-Separation-Training/MelBand Roformer Similarity/config_mel_band_roformer_similarity.yaml'
    start_check_point = '/content/Music-Source-Separation-Training/MelBand Roformer Similarity/model_mel_band_roformer_ep_25_sdr_15.0049.ckpt'
    !gdown --folder "https://drive.google.com/drive/folders/1uJP5OQuChCQVY4CVB1Ju3nxBskE-dYzy?usp=drive_link"
    # Apply the custom overlap and chunk_size values from the form above
    conf_edit(config_path, chunk_size, overlap)
else:
    raise ValueError(f"Unknown model: '{model}'")

supported_extensions = {'.aac', '.aif', '.aiff', '.flac', '.m4a', '.mp3', '.ogg', '.opus', '.wav', '.wv'}
input_path = Path(input_folder)

if not input_path.exists():
    raise FileNotFoundError(f"Input folder '{input_folder}' does not exist.")

audio_files = sorted(
    [path for path in input_path.iterdir() if path.is_file() and path.suffix.lower() in supported_extensions]
)

if not audio_files:
    raise ValueError(f"No supported audio files found in '{input_folder}'.")

def build_command(file_path):
    cmd = [
        "python",
        "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--start_check_point", start_check_point,
        "--input-file", str(file_path),
        "--store_dir", output_folder,
    ]
    if extract_instrumental:
        cmd.append("--extract_instrumental")
    if flac_file:
        # pcm_type is always set when a FLAC format is selected
        cmd.append("--flac_file")
        cmd.extend(["--pcm_type", pcm_type])
    if use_tta:
        cmd.append("--use_tta")
    if swap_stereo:
        cmd.append("--swap_stereo")
    return cmd

def run_inference(file_path):
    cmd = build_command(file_path)
    print(f"Starting inference for {file_path.name}")
    result = subprocess.run(cmd, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        stderr = result.stderr.strip() or 'Unknown error'
        raise RuntimeError(f"Inference failed for {file_path.name}: {stderr}")
    print(f"Finished inference for {file_path.name}")

# Dispatch inference calls concurrently, up to three workers.
max_workers = min(3, len(audio_files))
failures = []

with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = {executor.submit(run_inference, file_path): file_path for file_path in audio_files}
    for future in as_completed(futures):
        file_path = futures[future]
        try:
            future.result()
        except Exception as error:
            print(error)
            failures.append(file_path.name)

if failures:
    print("Inference completed with errors:", ", ".join(failures))
else:
    print("Inference finished for all files.")

**INST-Mel-Roformers like v1/1x/2** have switched output file names: files labelled as vocals are actually instrumentals. <br>
This doesn't apply to all of them, e.g. not to v1e: if you uncheck extract_instrumental for v1e, only one stem called "other" will be rendered, and it will be the instrumental.<br><br>
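
If you use one of the INST models with switched labels, renaming the outputs afterwards avoids mixing them up. A minimal sketch, assuming output files are named `<track>_<stem>.<ext>` (an assumption; check the actual names in your output folder, since the layout may differ between repo versions). `swapped_label_renames` and `apply_renames` are hypothetical helpers, not part of the repo:

```python
from pathlib import Path

def swapped_label_renames(paths, a="vocals", b="instrumental"):
    """Map '<track>_<a>.<ext>' to '<track>_<b>.<ext>' and vice versa.

    Hypothetical helper: assumes the '<track>_<stem>.<ext>' naming pattern.
    It only computes the new names and does not touch any file.
    """
    renames = {}
    for p in map(Path, paths):
        for old, new in ((a, b), (b, a)):
            tag = f"_{old}"
            if p.stem.endswith(tag):
                renames[p] = p.with_name(p.stem[: -len(tag)] + f"_{new}" + p.suffix)
                break
    return renames

def apply_renames(renames):
    # Two-phase rename so that swapping A <-> B never overwrites a file.
    temps = {}
    for src, dst in renames.items():
        tmp = src.with_name(src.name + ".tmp_swap")
        src.rename(tmp)
        temps[tmp] = dst
    for tmp, dst in temps.items():
        tmp.rename(dst)
```

For example, `swapped_label_renames(Path(output_folder).glob("*.wav"))` previews the renames before `apply_renames` performs them.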
**Mel Karaoke by becruily and Duality** models output 2 stems and don't need the "**extract_instrumental**" option enabled (the inverted stem will likely have worse quality than the model output, and you won't end up with 3 output files).
<br><br>
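
The inverted stem is the original mixture minus the model's output, so any error in the model's stem carries over into the inverted one. A minimal NumPy sketch of that subtraction, using synthetic sine waves in place of real audio (`invert_stem` and the signals are illustrative, not the repo's code):

```python
import numpy as np

def invert_stem(mixture: np.ndarray, stem: np.ndarray) -> np.ndarray:
    """Return mixture - stem, i.e. the residual ("inverted") stem."""
    if mixture.shape != stem.shape:
        raise ValueError("mixture and stem must have the same shape")
    return mixture - stem

sr = 44100
t = np.arange(sr) / sr
vocals = 0.5 * np.sin(2 * np.pi * 440 * t)        # stand-in "vocals"
instrumental = 0.3 * np.sin(2 * np.pi * 110 * t)  # stand-in "instrumental"
mixture = np.stack([vocals + instrumental] * 2)   # stereo mix, shape (2, sr)

# Imperfect vocal estimate: 5% of the instrumental leaks into it.
vocals_est = np.stack([vocals + 0.05 * instrumental] * 2)
inst_inverted = invert_stem(mixture, vocals_est)

# Whatever leaked into the vocal estimate is now missing from the inverted stem.
error = np.max(np.abs(inst_inverted - np.stack([instrumental] * 2)))
print(f"max error in inverted stem: {error:.3f}")
```
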
**TTA** (test-time augmentation) results in longer separation time; "it gives a little better SDR score but hard to tell if it's really audible in most cases". <br> It “means "test time augmentation", (...) it will do 3 passes on the audio file instead of 1. 1 pass will be with original audio. 1 will be with inverted stereo (L becomes R, R becomes L). 1 will be with phase inverted and then results are averaged for final output.” - jarredou
<br><br>
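
The three passes described above can be sketched as follows, assuming a hypothetical `separate` callable that maps a `(2, samples)` stereo mixture to a stem estimate of the same shape (it stands in for the model and is not a repo function). Each augmentation is undone on the output before averaging:

```python
from typing import Callable

import numpy as np

def separate_with_tta(mix: np.ndarray,
                      separate: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
    """Average three passes: original, L/R-swapped and polarity-inverted."""
    passes = [
        separate(mix),              # 1) original audio
        separate(mix[::-1])[::-1],  # 2) channels swapped, output swapped back
        -separate(-mix),            # 3) polarity inverted, output re-inverted
    ]
    return np.mean(passes, axis=0)
```

For a model that treats both channels and both polarities identically, all three passes agree and TTA changes nothing; the averaging only helps where the model's errors differ between passes, which is why it triples the separation time for a small gain.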
**Overlap** - higher values mean longer separation time. 4 is already a balanced value; 2 is fast, and some people still won't notice any difference. Normally there's no point going over 8.<br><br>
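
Why overlap drives separation time can be estimated by counting model passes. A rough sketch, assuming consecutive windows of `chunk_size` samples advance by `chunk_size // overlap` samples (an assumption about the inference loop; the repo's actual padding and windowing may differ slightly). `num_windows` is a hypothetical helper:

```python
import math

def num_windows(track_samples: int, chunk_size: int, overlap: int) -> int:
    """Approximate number of model passes over one track."""
    step = chunk_size // overlap  # assumed hop between consecutive windows
    if track_samples <= chunk_size:
        return 1
    return 1 + math.ceil((track_samples - chunk_size) / step)

sr = 44100
track = 4 * 60 * sr  # a 4-minute track at 44.1 kHz
for ov in (2, 4, 8):
    print(f"overlap={ov}: ~{num_windows(track, 485100, ov)} passes")
```

Under this assumption, doubling the overlap roughly doubles the number of passes (about 43, 85 and 168 here), so separation time grows almost linearly with it.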
**Chunk_size** - most models use the default 485100 (except Beta6X - 529200, 6 stems SW - 588800 [882000 will probably crash on a free T4], and Amane's 4 stems Large - 661500) and achieve their highest SDR with this value (which is higher than their training chunk size), but some people occasionally use 112455 to get different results.<br><br>
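
Chunk sizes are given in samples, so converting them to seconds makes the options easier to compare. A small sketch, assuming a 44.1 kHz sample rate (an assumption; check `sample_rate` in the model's config). `chunk_seconds` is a hypothetical helper:

```python
SAMPLE_RATE = 44100  # assumption: the model works on 44.1 kHz audio

def chunk_seconds(chunk_size: int, sample_rate: int = SAMPLE_RATE) -> float:
    """Length of one inference chunk in seconds."""
    return chunk_size / sample_rate

for cs in (112455, 485100, 529200, 588800, 661500):
    print(f"{cs:>7} samples = {chunk_seconds(cs):.2f} s")
```

At 44.1 kHz the default 485100 is exactly 11 s, and 112455 is about 2.55 s.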

If your separation can't start and "**Total files found: 0**" (or "No supported audio files found") is shown, check that: <br>1) Input is a path to a folder containing audio files, not a direct path to an audio file.<br> 2) The folder name matches exactly, since Colab paths are case-sensitive - e.g. the default path expects a folder named "input", not "Input".<br> 3) Your Google Drive was mounted correctly. Open the file manager on the left and check that your drive folder is not empty. If it is, force a remount with the following line:

In [None]:
drive.mount("/content/drive", force_remount=True)

4) Consider uploading your files to the input folder on GDrive before running the Colab. Rarely, files may stay invisible in the file manager even after refreshing the view; running the cell above 2 or 3 times usually fixes it.
<br>
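
The checks above can be run in a cell before separating. A minimal sketch: `diagnose_input` is a hypothetical helper, and `SUPPORTED` copies the extension set used by the separation cell:

```python
from pathlib import Path

SUPPORTED = {'.aac', '.aif', '.aiff', '.flac', '.m4a', '.mp3', '.ogg', '.opus', '.wav', '.wv'}

def diagnose_input(folder: str) -> list[str]:
    """Return human-readable problems with the input folder (empty list if OK)."""
    path = Path(folder)
    problems = []
    if not path.exists():
        # Look for a sibling whose name differs only by case ("Input" vs "input").
        if path.parent.is_dir():
            for sibling in path.parent.iterdir():
                if sibling.name.lower() == path.name.lower():
                    problems.append(f"'{folder}' not found, but '{sibling}' exists (paths are case-sensitive)")
                    break
        if not problems:
            problems.append(f"'{folder}' does not exist (is Google Drive mounted?)")
        return problems
    if path.is_file():
        problems.append(f"'{folder}' is a file; input_folder must be the folder containing it")
        return problems
    audio = [p for p in path.iterdir() if p.is_file() and p.suffix.lower() in SUPPORTED]
    if not audio:
        problems.append(f"no supported audio files in '{folder}'")
    return problems
```

For example, `print(diagnose_input(input_folder) or "Input folder looks OK")`.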

**"Propagation unsuccessful" error** - click the "Connect" button again in the top-right corner (this can happen if you delete the environment manually and start again from the first cell).