[![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/Q5Q811R5YI)  
# Apollo-Colab-Inference [![Open In Github](https://img.shields.io/badge/github-code-green)](https://github.com/jarredou/Apollo-Colab-Inference/)  


*Original work [Apollo: Band-sequence Modeling for High-Quality Music Restoration in Compressed Audio](https://github.com/JusperLee/Apollo)*  

The original Apollo model was trained to restore/enhance lossy MP3 audio encoded at bitrates of 128 kbps or lower.  
<br>
___  
*changelog:*

<font size=2>**v0.5**  
<font size=2>- added: Lew's universal model  

<font size=2>**v0.4**  
<font size=2>- added: config loader  
<font size=2>- added: Lew's separated vocals enhancer v2 beta

<font size=2>**v0.3**  
<font size=2>- added: Lew's separated vocals enhancer model

<font size=2>**v0.2**  
<font size=2>- added: overlap feature  
<font size=2>- added: new inference.py for easier local CLI use  

<font size=2>**v0.1**  
<font size=2>- added: chunking for long audio inputs  
<font size=2>- ~~added "dual mono" processing for stereo audio input (processing each channel independently)~~

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive; the default input/output folders of
# the inference cell live under this mount point.
drive.mount(mountpoint='/content/drive')

In [None]:
%%capture --no-stderr
#@markdown #Install
# Install cell: clone the Apollo-Training code, swap in the Colab inference
# script, and install the pinned Python dependencies. %%capture hides stdout;
# stderr is still shown (--no-stderr).
%cd /content/
# NOTE(review): the trailing `&& cd Apollo-Training` runs inside the shell
# spawned by `!` and does not change the notebook's working directory; the
# %cd further down is what actually moves into the repository. On a re-run
# the clone is expected to fail because the folder already exists.
!git clone https://github.com/SUC-DriverOld/Apollo-Training.git && cd Apollo-Training

# Remove the repository's inference.py first so the wget below saves the
# replacement as inference.py (not as a numbered duplicate).
!rm -rf '/content/Apollo-Training/inference.py'
%cd /content/Apollo-Training
!wget 'https://raw.githubusercontent.com/Qupci/Apollo-Colab-Inference/main/inference.py'

# Extra dependencies; presumably required by inference.py / its config
# loader — TODO confirm.
!pip install omegaconf ml_collections

# Replace the preinstalled PyTorch with torch 2.5.1 wheels from the cu124
# (CUDA 12.4) index.
# NOTE(review): `yes |` is redundant with `-y` on the uninstall.
!yes | pip uninstall -y torch torchvision torchaudio
!yes | pip install torch==2.5.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

In [None]:
%cd /content/Apollo-Training
# Work from the cloned repository: commands later in this cell invoke
# `inference.py` by a path relative to this directory.

import os
import subprocess

def download_file(url, save_path):
    """Download ``url`` to ``save_path`` using ``wget``.

    The parent directory of ``save_path`` is created first. ``wget -c``
    continues a partial download already present at ``save_path``, and
    ``-O`` fixes the output path regardless of the URL's file name.

    Args:
        url: Remote file to fetch.
        save_path: Local destination path.

    Raises:
        RuntimeError: if ``wget`` cannot be started (for example it is not
            on PATH) or exits with a non-zero status.
    """
    # Ensure parent directory exists ('.' when save_path has no directory part).
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    cmd = ["wget", "-c", "--show-progress", "-O", save_path, url]
    try:
        proc = subprocess.run(cmd)
    except OSError as err:
        # A missing or non-executable wget raises FileNotFoundError/OSError;
        # report it through the same RuntimeError contract as a failed download.
        raise RuntimeError(f"could not run wget for url {url}: {err}") from err
    if proc.returncode != 0:
        raise RuntimeError(f"wget failed with exit code {proc.returncode} for url {url}")
    print(f"Descargado: {save_path}")

# Destination folders (inside the cloned repository) for downloaded
# checkpoints and config files.
models_folder = '/content/Apollo-Training/model'
configs_folder = '/content/Apollo-Training/configs'
# Ensure model and config folders exist before attempting downloads
os.makedirs(models_folder, exist_ok=True)
os.makedirs(configs_folder, exist_ok=True)

#@markdown #Inference
#@markdown For the universal model set *chunk_size* below to 19, for all other models set it to 25
# Colab form fields. Every entry of input_folder_path is processed and the
# result is written as <name>.wav into output_folder_path.
input_folder_path = '/content/drive/MyDrive/input' #@param {type:"string"}
output_folder_path = '/content/drive/MyDrive/output' #@param {type:"string"}
model = 'Baicai1145 Vocal MSST' #@param ['MP3 Enhancer', 'Lew Vocal Enhancer', 'Lew Vocal Enhancer v2 (beta)', 'Lew Universal Lossy Enhancer', 'Baicai1145 Vocal MSST']
# chunk_size and overlap are passed unchanged to inference.py as
# --chunk_size / --overlap.
chunk_size = 25 #@param {type:"slider", min:3, max:25, step:1}
overlap = 2 #@param {type:"slider", min:2, max:10, step:1}

# Download sources for every selectable model: name -> (checkpoint URL, config URL).
# A config URL of None means the model uses configs/apollo.yaml, which ships
# with the Apollo-Training repository, instead of a downloaded config.
# Local file names are the last path component of each URL.
MODEL_SOURCES = {
    'MP3 Enhancer': (
        "https://huggingface.co/JusperLee/Apollo/resolve/main/pytorch_model.bin",
        None,
    ),
    'Lew Vocal Enhancer': (
        "https://huggingface.co/jarredou/lew_apollo_vocal_enhancer/resolve/main/apollo_model.ckpt",
        "https://huggingface.co/jarredou/lew_apollo_vocal_enhancer/resolve/main/config_apollo_vocal.yaml",
    ),
    'Lew Vocal Enhancer v2 (beta)': (
        "https://huggingface.co/jarredou/lew_apollo_vocal_enhancer/resolve/main/apollo_model_v2.ckpt",
        "https://huggingface.co/jarredou/lew_apollo_vocal_enhancer/resolve/main/config_apollo_vocal.yaml",
    ),
    'Lew Universal Lossy Enhancer': (
        "https://github.com/deton24/Lew-s-vocal-enhancer-for-Apollo-by-JusperLee/releases/download/uni/apollo_model_uni.ckpt",
        "https://github.com/deton24/Lew-s-vocal-enhancer-for-Apollo-by-JusperLee/releases/download/uni/config_apollo_uni.yaml",
    ),
    'Baicai1145 Vocal MSST': (
        "https://huggingface.co/baicai1145/Apollo-vocal-msst/resolve/main/model_apollo_vocals_ep_54.ckpt",
        "https://huggingface.co/baicai1145/Apollo-vocal-msst/resolve/main/config_apollo_vocals_ep_54.yaml",
    ),
}

if model not in MODEL_SOURCES:
    # Fail loudly rather than leaving ckpt/config undefined, or silently
    # reusing the values from an earlier run of this cell.
    raise ValueError(f"Unknown model {model!r}; choose one of {sorted(MODEL_SOURCES)}")

model_url, config_url = MODEL_SOURCES[model]
model_filename = model_url.rsplit('/', 1)[-1]
ckpt = os.path.join(models_folder, model_filename)
download_file(model_url, ckpt)

if config_url is None:
    # Absolute path to the repository's bundled config, so it does not depend
    # on the current working directory.
    config = os.path.join(configs_folder, 'apollo.yaml')
else:
    config_filename = config_url.rsplit('/', 1)[-1]
    config = os.path.join(configs_folder, config_filename)
    download_file(config_url, config)

# Get list of input files
input_files = [os.path.join(input_folder_path, f) for f in os.listdir(input_folder_path)]

# Create output folder if it doesn't exist
if not os.path.exists(output_folder_path):
    os.makedirs(output_folder_path)

# Process each input file
for input_file in input_files:
    filename, file_extension = os.path.splitext(os.path.basename(input_file))
    output_file = os.path.join(output_folder_path, f"{filename}.wav")
    !python inference.py \
        --in_wav '{input_file}' \
        --out_wav '{output_file}' \
        --chunk_size {chunk_size} \
        --overlap {overlap} \
        --ckpt '{ckpt}' \
        --config '{config}'