In [2]:
from demucs import pretrained
import torch
from demucs.demucs import Demucs
from demucs.hdemucs import HDemucs
from demucs.apply import tensor_chunk
from demucs.htdemucs import HTDemucs
from demucs.utils import center_trim
from demucs.apply import TensorChunk
from demucs.audio import AudioFile, convert_audio, save_audio
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import scipy
from scipy.signal import resample, butter, filtfilt, cheby1
import os
import numpy as np
import torch
from fvcore.nn import FlopCountAnalysis
import warnings
import sys
import io
import torch.nn.utils.prune as prune
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from demucs.transformer import MyTransformerEncoderLayer, CrossTransformerEncoderLayer, dynamic_sparse_attention, MultiheadAttention, scaled_dot_product_attention
from torch.quantization import quantize_dynamic
from fractions import Fraction

In [3]:
from demucs.separate import Separator

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Separator built around the pretrained "htdemucs" model. Later cells read
# `separator._audio_channels` and `separator._model.sources` from it.
separator = Separator(
    model="htdemucs", repo=None, device=device,
    shifts=1, overlap=0.25, split=True,
    segment=None, jobs=None,
    callback=print,
)

# Notebook-wide settings. `samplerate` is read by the audio loading and saving
# helpers defined further down; the other three are plain placeholders here.
samplerate = 44100
segment = callback = length = None

# Last expression of the cell: show which device was picked.
device

'cpu'

In [4]:
# Function to count the number of parameters of a torch model
def count_parameters(model):
    """Return the total element count of `model`'s trainable parameters.

    Only parameters whose `requires_grad` flag is truthy are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [5]:
# Fetch the pretrained 'htdemucs' entry; its first sub-model serves as the teacher.
model_htdemucs = pretrained.get_model('htdemucs')
teacher_model = model_htdemucs.models[0]

# Switch `use_train_segment` off on both the container and the teacher.
# NOTE(review): presumably this lets the model accept inputs that are not
# exactly one training segment long -- confirm in demucs.htdemucs.
model_htdemucs.use_train_segment = teacher_model.use_train_segment = False

In [61]:
# Constructor arguments used as the teacher's configuration, one key per line
# (same keys, values and order as the original single-line literal).
teacher_kwargs = {
    'sources': ['drums', 'bass', 'other', 'vocals'],
    'audio_channels': 2,
    'samplerate': 44100,
    'segment': Fraction(39, 5),
    'channels': 48,
    'channels_time': None,
    'growth': 2,
    'nfft': 4096,
    'wiener_iters': 0,
    'end_iters': 0,
    'wiener_residual': False,
    'cac': True,
    'depth': 4,
    'rewrite': True,
    'multi_freqs': [],
    'multi_freqs_depth': 3,
    'freq_emb': 0.2,
    'emb_scale': 10,
    'emb_smooth': True,
    'kernel_size': 8,
    'stride': 4,
    'time_stride': 2,
    'context': 1,
    'context_enc': 0,
    'norm_starts': 4,
    'norm_groups': 4,
    'dconv_mode': 3,
    'dconv_depth': 2,
    'dconv_comp': 8,
    'dconv_init': 0.001,
    'bottom_channels': 512,
    't_layers': 5,
    't_hidden_scale': 4.0,
    't_heads': 8,
    't_dropout': 0.02,
    't_layer_scale': True,
    't_gelu': True,
    't_emb': 'sin',
    't_max_positions': 10000,
    't_max_period': 10000.0,
    't_weight_pos_embed': 1.0,
    't_cape_mean_normalize': True,
    't_cape_augment': True,
    't_cape_glob_loc_scale': [5000.0, 1.0, 1.4],
    't_sin_random_shift': 0,
    't_norm_in': True,
    't_norm_in_group': False,
    't_group_norm': False,
    't_norm_first': True,
    't_norm_out': True,
    't_weight_decay': 0.0,
    't_lr': None,
    't_sparse_self_attn': False,
    't_sparse_cross_attn': False,
    't_mask_type': 'diag',
    't_mask_random_seed': 42,
    't_sparse_attn_window': 400,
    't_global_window': 100,
    't_sparsity': 0.95,
    't_auto_sparsity': False,
    't_cross_first': False,
    'rescale': 0.1,
}

# The student starts from a shallow copy of the teacher configuration and
# overrides a few entries. Knobs left at the teacher's values for now:
# depth (4), kernel_size (8), stride (4), t_heads (8).
student_kwargs = dict(teacher_kwargs)
student_kwargs.update(
    channels=12,     # teacher: 48
    time_stride=2,   # teacher: 2
    t_layers=5,      # teacher: 5
)

student_model = HTDemucs(**student_kwargs)
student_model.use_train_segment = False

In [62]:
# Report the trainable-parameter count of each network, teacher first.
for label, net in (("teacher", teacher_model), ("student", student_model)):
    print(f"{count_parameters(net):,} parameters in {label} model")

41,984,456 parameters in teacher model
32,367,128 parameters in student model


In [74]:
def _timed_forward(model, batch):
    """Run `model(batch)` with gradients disabled.

    Returns `(output, start, end)` where `start`/`end` are `time.time()`
    stamps taken immediately before and after the forward pass.
    """
    started = time.time()
    with torch.no_grad():
        result = model(batch)
    finished = time.time()
    return result, started, finished


# Random example input of shape (1, 2, 44100).
audio_input = torch.randn(1, 2, 44100)

# Teacher forward pass.
teacher_separated_sources, teacher_start, teacher_end = _timed_forward(teacher_model, audio_input)
print("Time taken for teacher model: ", teacher_end - teacher_start)
print("Teacher model output shape: ", teacher_separated_sources.shape)

# Student forward pass on the same input.
student_separated_sources, student_start, student_end = _timed_forward(student_model, audio_input)
print("Time taken for student model: ", student_end - student_start)
print("Student model output shape: ", student_separated_sources.shape)

Time taken for teacher model:  0.40648746490478516
Teacher model output shape:  torch.Size([1, 4, 2, 44100])
Time taken for student model:  0.25122547149658203
Student model output shape:  torch.Size([1, 4, 2, 44100])


In [41]:
# Trainable-parameter count of the teacher network.
teacher_param_count = count_parameters(teacher_model)
print(f"{teacher_param_count:,} parameters in teacher model")

41,984,456 parameters in teacher model


In [None]:
def _lowpass_then_decimate(wav, b, a, decimation_factor):
    """Low-pass `wav` along axis 1 with `filtfilt(b, a, ...)`, then keep every
    `decimation_factor`-th sample and return the result as a float32 tensor."""
    filtered = filtfilt(b, a, wav, axis=1)
    decimated = filtered[:, ::decimation_factor]
    # np.copy materialises the strided slice into a fresh array before it is
    # handed to torch.
    return torch.tensor(np.copy(decimated), dtype=torch.float32)


def get_filtered_audio(file, method):
    """Load `file` and optionally decimate it along the sample axis.

    The audio is read with `AudioFile(file).read(...)` using the notebook-level
    `samplerate` and `separator._audio_channels`.

    Parameters
    ----------
    file : path of the audio file to read.
    method : sequence whose first item selects the processing:
        * `None` -- return the audio as read.
        * `"decimation_without_filtering"` -- `method[1]` is the decimation
          factor; every factor-th sample is kept, with no anti-alias filter.
        * `"decimation_with_butterworth_filter"` -- `method[1]` is
          `(cutoff, order, decimation_factor)`.
        * `"decimation_with_chebyshev_filter"` -- `method[1]` is
          `(cutoff, order, ripple, decimation_factor)`.
        `cutoff` is expressed in the same unit as `samplerate`; it is divided
        by the Nyquist frequency (`0.5 * samplerate`) before the filter design.

    Returns
    -------
    (wav, original_length) : the (possibly decimated) audio and the number of
        samples along axis 1 before any decimation. The filtered variants
        return a new float32 tensor.

    Raises
    ------
    AssertionError : if `method[0]` is not one of the values above.
    """
    wav = AudioFile(file).read(streams=0, samplerate=samplerate, channels=separator._audio_channels)
    original_length = wav.shape[1]
    kind = method[0]

    if kind is None:
        return wav, original_length

    if kind == "decimation_without_filtering":
        decimation_factor = method[1]
        return wav[:, ::decimation_factor], original_length

    if kind == "decimation_with_butterworth_filter":
        cutoff, order, decimation_factor = method[1]
        normal_cutoff = cutoff / (0.5 * samplerate)
        b, a = butter(order, normal_cutoff, btype='low', analog=False)
        return _lowpass_then_decimate(wav, b, a, decimation_factor), original_length

    if kind == "decimation_with_chebyshev_filter":
        cutoff, order, ripple, decimation_factor = method[1]
        normal_cutoff = cutoff / (0.5 * samplerate)
        b, a = cheby1(order, ripple, normal_cutoff, btype='low', analog=False)
        return _lowpass_then_decimate(wav, b, a, decimation_factor), original_length

    # Raised explicitly rather than via `assert False`: assert statements are
    # removed under `python -O`, which would make this function fall through
    # and return None. The exception type and message are unchanged.
    raise AssertionError("Invalid method")

def interpolate_wav_file(wav, original_length):
    """Resample `wav` along axis 1 so that it has `original_length` samples.

    Thin wrapper around `scipy.signal.resample`; returns whatever that returns.
    """
    restored = resample(wav, num=original_length, axis=1)
    return restored

def _resample_to_length(signal, target_length, axis):
    """Return `signal` as a tensor with `target_length` samples along `axis`."""
    if isinstance(signal, torch.Tensor):
        # SciPy needs a NumPy array: drop autograd tracking and move the data
        # to host memory first, otherwise the conversion raises for tensors
        # that require grad or live on an accelerator.
        signal = signal.detach().cpu()
        if signal.shape[axis] == target_length:
            # Already the requested length: skip the FFT round-trip, which
            # would only cost time and add rounding error.
            return signal.clone()
        signal = signal.numpy()
    return torch.tensor(resample(signal, target_length, axis=axis))


def clean_up_out_wav(out, wav, original_length):
    """Bring the model output and the mixture back to `original_length` samples.

    Parameters
    ----------
    out : separated sources; resampled along axis 3.
    wav : mixture; resampled along axis 1.
    original_length : number of samples each result should have.

    Returns
    -------
    (out, wav) : new tensors, in that order. Inputs are not modified.
    """
    wav = _resample_to_length(wav, original_length, axis=1)
    out = _resample_to_length(out, original_length, axis=3)
    return out, wav

In [9]:
# @track_emissions()
def run_separator_htdemucs(model, file, output_save_folder = "random_files", save_audio_flag=True, method=[None]):
    """Separate `file` with `model` and time the forward pass.

    Steps: load (and optionally decimate) the audio via `get_filtered_audio`,
    normalise it with the mean/std of its channel-average, run `model`, undo
    the normalisation, bring everything back to the original length with
    `clean_up_out_wav`, then either write one file per stem or collect the
    stems in a dict.

    Parameters
    ----------
    model : callable taking the batched mixture and returning a tensor whose
        first item (`out[0]`) iterates over the stems.
    file : path of the audio file to separate.
    output_save_folder : directory for the stem files; created if missing.
    save_audio_flag : when True each stem is written as "<stem>.mp3" into
        `output_save_folder`; when False the stems are returned instead.
    method : decimation spec forwarded unchanged to `get_filtered_audio`.
        The list default is never mutated here.

    Returns
    -------
    (inference_time, None, None, last_ret) : `inference_time` is the duration
        of the `model(mix)` call in seconds; `last_ret` maps stem name to
        tensor and is only filled when `save_audio_flag` is False.
    """
    # A single no_grad context covers the whole function.
    with torch.no_grad():
        os.makedirs(output_save_folder, exist_ok=True)
        wav, original_length = get_filtered_audio(file, method)

        # Normalisation statistics come from the channel-average and are
        # computed once; `ref` is a separate tensor, so the in-place updates
        # of `wav` below do not change them.
        ref = wav.mean(0)
        ref_mean = ref.mean()
        ref_scale = ref.std() + 1e-8
        wav -= ref_mean
        wav /= ref_scale
        mix = wav[None]  # prepend an axis (presumably the batch axis)

        # perf_counter is monotonic and high resolution, unlike time.time().
        start_time = time.perf_counter()
        out = model(mix)
        end_time = time.perf_counter()
        assert isinstance(out, torch.Tensor)

        # Undo the normalisation on both the output and the mixture.
        out *= ref_scale
        out += ref_mean
        wav *= ref_scale
        wav += ref_mean
        out, wav = clean_up_out_wav(out, wav, original_length)

        # Label stems with the sources of the model that actually produced
        # them; fall back to the notebook-level separator when the model does
        # not expose `sources`.
        stem_names = getattr(model, "sources", None) or separator._model.sources
        separated = dict(zip(stem_names, out[0]))

        ext = "mp3"
        filename_format = "{stem}.{ext}"
        kwargs = {
            "samplerate": samplerate,
            "bitrate": 320,
            "clip": "rescale",
            "as_float": False,
            "bits_per_sample": 16,
        }
        last_ret = {}
        for stem, source in separated.items():
            stem_path = os.path.join(output_save_folder, filename_format.format(
                stem=stem,
                ext=ext,
            ))
            if save_audio_flag:
                save_audio(source, str(stem_path), **kwargs)
            else:
                last_ret[stem] = source

        inference_time = end_time - start_time
        return inference_time, None, None, last_ret

In [10]:
# Separate the short test clip with the teacher, without decimation.
# The cell defining get_filtered_audio must have been executed first,
# otherwise this call fails with a NameError.
run_separator_htdemucs(
    model=teacher_model,
    file="my_test_short.mp4",
    method=[None],
)

NameError: name 'get_filtered_audio' is not defined