<a href="https://colab.research.google.com/github/THEMANNICHOLAS/Stem-Separator/blob/main/Stem_Separation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#import the demucs and open-unmix models
!pip install demucs
!pip install openunmix



In [2]:
from demucs.audio import AudioFile
from demucs.apply import apply_model
from demucs.pretrained import get_model
from demucs.separate import load_track
from google.colab import files

import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import librosa
import subprocess
import os


In [51]:
class Stem_Separator:


  def __init__(self):
    """
    Initialize the Stem Separator class and load the model.
    """
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.model = get_model('htdemucs')

    #Verify that the model can be properly imported and used.
    try:
      self.model = get_model('htdemucs')
    except Exception as e:
      print(f'Error loading model: {e}')



  def load_audio_file(self, user_file = None):
    """
    Loads an audio file, either from a user-provided path or by uploading.
    Converts the file into a PyTorch tensor and performs basic pre-processing
    (e.g., converting stereo to mono).

    Parameters:
      user_file (str): The path to the uploaded file if in Google Colab.

    Returns:
      None
    """
    if user_file:
      self.file = user_file
    else:
      uploaded_file = files.upload()
      self.file = list(uploaded_file.keys())[0]
      if uploaded_file:
        print(f'Uploaded file: {self.file}')
      else:
        print('No file uploaded')
    self.sample_rate = torchaudio.info(self.file).sample_rate


  def convert_file(self):
    file_path, file_extention = os.path.splitext(self.file)
    if file_extention.lower() != '.wav':
      #We will convert to .wav using ffmpeg
      converted_file = file_path + '.wav'
      try:
        subprocess.run([
            'ffmpeg',
            '-i', self.file,
            '-ar', '44100',
            '-ac', '2',
            '-y',
            converted_file
        ], check = True)
        return True
      except subprocess.CalledProcessError as e:
        print(f'Error converting file to .wav: {e}')
        return False


  def convert_to_audio_tensor(self):
    #Try to convert the audio file into a loadable torchaudio tensor.
    try:
        self.audio_tensor = torchaudio.load(self.file)[0]
    except Exception as e:
        print(f'Error loading or converting audio file: {e}')


  def apply_high_pass(self, cutoff = 20):
    """
    Applies a high-pass filter to the audio tensor.

    Parameters:
      cutoff (float): The cutoff frequency for the high-pass filter.
      sample_rate (int): The sample rate of the audio tensor.
    """
    self.audio_tensor = torchaudio.functional.highpass_biquad(
        self.audio_tensor, cutoff, self.sample_rate)
    return self.audio_tensor


  def apply_low_pass(self, cutoff = 20000):
    """
    Applies a low-pass filter to the audio tensor.

    Parameters:
      cutoff (float): The cutoff frequency for the low-pass filter.
      sample_rate (int): The sample rate of the audio tensor.
    """
    self.audio_tensor = torchaudio.functional.lowpass_biquad(
        self.audio_tensor, cutoff, self.sample_rate)
    return self.audio_tensor



  def preprocess_audio(self):
    """
    Preprocesses the audio by applying high-pass and low-pass filters.

    Returns:
        torch.Tensor: The preprocessed audio tensor.
    """
    self.convert_file()
    self.convert_to_audio_tensor()
    self.apply_high_pass()
    self.apply_low_pass()



  def stem_separate(self,  model = None, stem = None):
    """
    Separates the stem of the audio tensor given a choice from the user.

    Returns:
      torch.Tensor: The separated stem of the audio tensor.
    """
    # Model selection
    source_model = model.casefold() if model and model.casefold() in ['htdemucs', 'openunmix'] else 'htdemucs'

    # Stem selection
    stems = ['vocals', 'drums', 'bass', 'other']
    self.separation_type = stem.casefold() if stem and stem.casefold() in stems else 'vocals'

    extracted = apply_model(self.model,
                            self.audio_tensor.unsqueeze(0),
                            device = self.device,
                            split = True)
    print(extracted[0].shape)

    self.separated_stem = extracted[0]


  def save_audio(self):

 #   if self.separation_type:
 #     output_file_name = f'{self.file}_{self.separation_type}.wav'
 #   else:
 #     output_file_name = f'{self.file}_(vocals).wav'

    output_file_name = f'{self.file}_(vocals).wav'

    if self.separated_stem.dim()==3:
      self.separated_stem = self.separated_stem.squeeze(0)
    elif self.separated_stem.dim()>3:
      raise ValueError('Tensor has more than 3 dimensions')

    #Mix down to mono if stereo audio
    if self.separated_stem.shape[0] == 2:
      self.separated_stem = torch.mean(self.separated_stem, dim = 0, keepdim = True)

    torchaudio.save(output_file_name, self.separated_stem[3], self.sample_rate)
    files.download(output_file_name)








In [53]:

separator = Stem_Separator()

separator.load_audio_file()

separator.preprocess_audio()

stem = separator.stem_separate(model = 'htdemucs', stem = 'vocals')

separator.save_audio()


IndexError: list index out of range