In [1]:
from demucs import pretrained
import torch
from demucs.demucs import Demucs
from demucs.hdemucs import HDemucs
from demucs.apply import tensor_chunk
from demucs.htdemucs import HTDemucs
from demucs.utils import center_trim
from demucs.apply import TensorChunk
from demucs.audio import AudioFile, convert_audio, save_audio
from pathlib import Path
import demucs
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import scipy
from scipy.signal import resample, butter, filtfilt, cheby1
import os
import numpy as np
import torch
import warnings
import sys
import io
import torch.nn.utils.prune as prune
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from demucs.transformer import MyTransformerEncoderLayer, CrossTransformerEncoderLayer, dynamic_sparse_attention, MultiheadAttention, scaled_dot_product_attention
from torch.quantization import quantize_dynamic
from fractions import Fraction
import kd_helper
from demucs.solver import Solver
import logging
from demucs import distrib
import hydra
from hydra.core.global_hydra import GlobalHydra
from dora import hydra_main
logger = logging.getLogger(__name__)

In [2]:
from demucs.separate import Separator

# Use the GPU when torch can see one; the string is echoed as the cell's output.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Reference separator built around the pretrained "htdemucs" model. In the visible cells
# only its attributes are read (`_audio_channels` when loading audio, `_model.sources`
# for stem names); separation itself calls the models directly.
separator = Separator(
    model="htdemucs",
    repo=None,
    device=device,
    shifts=1,
    overlap=0.25,
    split=True,
    segment=None,
    jobs=None,
    callback=print
)
# NOTE(review): `segment`, `callback` and `length` are not read by any visible cell;
# presumably leftovers from demucs' separate script -- confirm before removing.
segment = None
callback = None
length = None
# Sample rate passed to AudioFile.read in get_filtered_audio and used for filter design.
samplerate = 44100
device

'cuda'

In [3]:
# Count only the trainable parameters of a torch model.
def count_parameters(model):
    """Return how many scalar parameters of ``model`` are trainable.

    Tensors with ``requires_grad=False`` (frozen weights) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [4]:
# Load the pretrained "htdemucs" bundle and take its first sub-model as a teacher.
# NOTE(review): `use_train_segment = False` presumably lets the model run on inputs of
# arbitrary length rather than its training segment -- confirm in demucs.
# `teacher_model` is rebound by the next cell's kd_helper call.
model_htdemucs = pretrained.get_model('htdemucs')
model_htdemucs.use_train_segment = False
teacher_model = model_htdemucs.models[0]
teacher_model.use_train_segment = False

In [5]:
# kd_helper returns (student, teacher); the parameter counts printed by the next cell
# confirm that order. partial_weight_copy=True copies part of the teacher's weights into
# the student (per the message it prints). This rebinds `teacher_model` from above.
student_model, teacher_model = kd_helper.get_student_teacher_models(partial_weight_copy=True)

Partial weights transferred successfully from the teacher to the student model.


In [6]:
# count_parameters only counts parameters with requires_grad=True.
print(f"{count_parameters(teacher_model):,} parameters in teacher model")
print(f"{count_parameters(student_model):,} parameters in student model")

41,984,456 parameters in teacher model
8,628,760 parameters in student model


In [7]:
def _timed_forward(model, inputs):
    """Run one no-grad forward pass of ``model`` and return ``(output, elapsed_seconds)``.

    ``inputs`` is moved to the device of the model's first parameter. On CUDA the
    device is synchronized before and after the call, so the measurement covers the
    kernels themselves and not just their asynchronous launch.
    """
    model_device = next(model.parameters()).device
    inputs = inputs.to(model_device)
    with torch.no_grad():
        if model_device.type == "cuda":
            torch.cuda.synchronize(model_device)
        start = time.perf_counter()  # monotonic, high-resolution clock for benchmarks
        output = model(inputs)
        if model_device.type == "cuda":
            torch.cuda.synchronize(model_device)
        elapsed = time.perf_counter() - start
    return output, elapsed


# Example input, presumably (batch, channels, samples) -- consistent with the
# (1, 4, 2, 44100) outputs. Each model is timed once, with no warm-up run.
audio_input = torch.randn(1, 2, 44100)
teacher_separated_sources, teacher_elapsed = _timed_forward(teacher_model, audio_input)
print("Time taken for teacher model: ", teacher_elapsed)
print("Teacher model output shape: ", teacher_separated_sources.shape)
student_separated_sources, student_elapsed = _timed_forward(student_model, audio_input)
print("Time taken for student model: ", student_elapsed)
print("Student model output shape: ", student_separated_sources.shape)

Time taken for teacher model:  3.7313716411590576
Teacher model output shape:  torch.Size([1, 4, 2, 44100])
Time taken for student model:  0.15351533889770508
Student model output shape:  torch.Size([1, 4, 2, 44100])


In [8]:
# Sanity check of where the student's weights live. The output shows CPU even though
# `device` evaluated to "cuda", so the student forward pass above ran on the CPU.
next(student_model.parameters()).device

device(type='cpu')

In [9]:
def get_filtered_audio(file, method):
    """Load ``file`` at the notebook sample rate and optionally low-pass + decimate it.

    Args:
        file: path of an audio file readable by demucs' ``AudioFile``.
        method: sequence whose first item selects the preprocessing:
            ``[None]`` -> return the audio unchanged;
            ``["decimation_without_filtering", factor]`` -> keep every factor-th sample;
            ``["decimation_with_butterworth_filter", (cutoff, order, factor)]``;
            ``["decimation_with_chebyshev_filter", (cutoff, order, ripple, factor)]``.
            ``cutoff`` is in the same units as ``samplerate``. Filtering runs at the
            original rate (zero-phase ``filtfilt``) before decimation.

    Returns:
        ``(wav, original_length)``. ``original_length`` is the sample count before
        decimation. ``wav`` is whatever ``AudioFile.read`` returns (optionally
        strided) for the first two modes, or a new float32 tensor for the filtered modes.

    Raises:
        ValueError: if ``method[0]`` is not one of the names above.
    """
    wav = AudioFile(file).read(streams=0, samplerate=samplerate, channels=separator._audio_channels)
    original_length = wav.shape[1]
    kind = method[0]

    if kind is None:
        return wav, original_length
    if kind == "decimation_without_filtering":
        decimation_factor = method[1]
        return wav[:, ::decimation_factor], original_length

    nyquist = 0.5 * samplerate
    if kind == "decimation_with_butterworth_filter":
        cutoff, order, decimation_factor = method[1]
        b, a = butter(order, cutoff / nyquist, btype='low', analog=False)
    elif kind == "decimation_with_chebyshev_filter":
        cutoff, order, ripple, decimation_factor = method[1]
        b, a = cheby1(order, ripple, cutoff / nyquist, btype='low', analog=False)
    else:
        # An explicit exception, unlike `assert`, survives `python -O`.
        raise ValueError(f"Invalid method: {kind!r}")

    filtered = filtfilt(b, a, wav, axis=1)
    decimated = filtered[:, ::decimation_factor]
    # The strided view is non-contiguous; make a compact copy before wrapping it.
    return torch.tensor(np.ascontiguousarray(decimated), dtype=torch.float32), original_length

def interpolate_wav_file(wav, original_length):
    """FFT-resample ``wav`` along axis 1 so it has ``original_length`` samples.

    Returns the array produced by ``scipy.signal.resample``.
    """
    time_axis = 1
    return resample(wav, num=original_length, axis=time_axis)

def clean_up_out_wav(out, wav, original_length):
    """Resample separated stems and the mix back to ``original_length`` samples.

    Undoes the decimation applied by ``get_filtered_audio``.

    Args:
        out: model output, resampled along axis 3 (the last axis of the 4-D output
            seen in this notebook).
        wav: mixture, resampled along axis 1.
        original_length: target number of samples.

    Returns:
        ``(out, wav)`` as new CPU torch tensors. Inputs may be torch tensors on any
        device (including ones that track gradients) or array-likes. An input that
        already has ``original_length`` samples is copied unchanged, with no lossy
        FFT round trip.
    """
    def _restore_length(x, axis):
        # scipy only works on host memory: detach and move torch tensors to the CPU.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        x = np.asarray(x)
        if x.shape[axis] != original_length:
            x = resample(x, original_length, axis=axis)
        # torch.tensor copies, so callers never alias the caller-owned input.
        return torch.tensor(x)

    wav = _restore_length(wav, axis=1)
    out = _restore_length(out, axis=3)
    return out, wav

In [10]:
def run_separator_htdemucs(model, file, output_save_folder = "random_files", save_audio_flag=True, method=None):
    """Separate ``file`` into stems with ``model`` and optionally save one MP3 per stem.

    Pipeline:
    1. Optional low-pass/decimation via ``get_filtered_audio``.
    2. Normalise the mix by the mean/std of its channel average.
    3. Run a single forward pass, then undo the normalisation.
    4. Resample the stems back to the original length.

    Args:
        model: separation model called as ``model(mix)`` with ``mix`` of shape
            ``(1, channels, samples)``; the input is moved to the model's device.
        file: input audio path.
        output_save_folder: directory for ``"<stem>.mp3"`` files (created if missing).
        save_audio_flag: if True, write the MP3s and return an empty dict;
            otherwise return the stems instead of writing them.
        method: preprocessing spec forwarded to ``get_filtered_audio``;
            ``None`` (the default) means ``[None]``, i.e. no preprocessing.

    Returns:
        ``(inference_time_seconds, None, None, stems)``. ``stems`` maps source name to
        a CPU tensor and is populated only when ``save_audio_flag`` is False.
    """
    if method is None:  # avoid a shared mutable default argument
        method = [None]
    with torch.no_grad():
        os.makedirs(output_save_folder, exist_ok=True)
        wav, original_length = get_filtered_audio(file, method)
        ref = wav.mean(0)
        offset = ref.mean()
        scale = ref.std() + 1e-8
        model_device = next(model.parameters()).device
        mix = ((wav - offset) / scale)[None].to(model_device)

        # Synchronise on CUDA so the timing covers the kernels, not just their launch.
        if model_device.type == "cuda":
            torch.cuda.synchronize(model_device)
        start_time = time.perf_counter()
        out = model(mix)
        if model_device.type == "cuda":
            torch.cuda.synchronize(model_device)
        inference_time = time.perf_counter() - start_time

        assert isinstance(out, torch.Tensor)
        # Back to the CPU for scipy resampling and audio saving, then de-normalise.
        out = out.cpu() * scale + offset
        out, _ = clean_up_out_wav(out, wav, original_length)

        # Label stems with the running model's own source order when it exposes one.
        sources = getattr(model, "sources", None) or separator._model.sources
        separated = dict(zip(sources, out[0]))

        ext = "mp3"
        kwargs = {
            "samplerate": samplerate,
            "bitrate": 320,
            "clip": "rescale",
            "as_float": False,
            "bits_per_sample": 16,
        }
        stems = {}
        for stem, source in separated.items():
            if save_audio_flag:
                stem_path = os.path.join(output_save_folder, f"{stem}.{ext}")
                save_audio(source, stem_path, **kwargs)
            else:
                stems[stem] = source
        return inference_time, None, None, stems

In [11]:
# Smoke test: separate a short local clip with the teacher and no preprocessing.
# Stems are written as MP3s to the default "random_files" folder, so the returned dict
# is empty; the first element is the forward-pass time in seconds.
run_separator_htdemucs(teacher_model, "my_test_short.mp4", method=[None])

(2.9280879497528076, None, None, {})

In [16]:
class Args:
    """Hand-written configuration object grouped like the demucs training config.

    In this notebook it is passed to ``demucs.train.get_datasets``,
    ``demucs.train.get_optimizer`` and ``KTSolver``. Sections are plain attributes:
    ``dset``, ``optim``, ``augment``, ``test`` and ``misc``.
    """

    def __init__(self):
        self.seed = 42
        self.batch_size = 64
        self.epochs = 2
        # Dataset related arguments
        self.dset = self.DatasetArgs()
        # Optimization related arguments
        self.optim = self.OptimArgs()
        # Augmentation related arguments
        self.augment = self.AugmentArgs()
        # Testing related arguments
        self.test = self.Test()
        # Miscellaneous arguments
        self.misc = self.MiscArgs()
        self.sources = ['drums', 'bass', 'other', 'vocals']

    class DatasetArgs:
        """Dataset location and loading options."""

        def __init__(self):
            # Set MUSDB_PATH to point at the dataset on another machine without editing
            # code; the default keeps the original hard-coded location.
            self.musdb = os.environ.get(
                'MUSDB_PATH', '/home/ubuntu/odml_final/distillation_demucs/MUSDBHQ')
            self.musdb_samplerate = 44100
            self.use_musdb = True
            self.wav = None  # path to custom wav dataset
            self.wav2 = None  # second custom wav dataset
            self.segment = 11
            self.shift = 1
            self.train_valid = False
            self.full_cv = True  # get_my_solver: validate with batch_size=1 when True
            self.samplerate = 44100
            self.channels = 2
            self.normalize = True
            self.metadata = './metadata'
            self.sources = ['drums', 'bass', 'other', 'vocals']
            self.valid_samples = None  # valid dataset size
            self.backend = None

    class OptimArgs:
        """Optimizer hyper-parameters."""

        def __init__(self):
            self.lr = 3e-4
            self.momentum = 0.9
            self.beta2 = 0.999
            self.loss = 'l1'  # l1 or mse
            self.optim = 'adam'
            self.weight_decay = 0
            self.clip_grad = 0

    class AugmentArgs:
        """Data-augmentation options."""

        def __init__(self):
            self.shift_same = False
            self.repitch = self.Repitch()
            self.remix = self.Remix()
            self.scale = self.Scale()
            self.flip = True

        class Repitch:
            # Attribute names double as RepitchedWrapper keyword arguments in get_my_solver.
            def __init__(self):
                self.proba = 0.2
                self.max_tempo = 12

        class Remix:
            def __init__(self):
                self.proba = 1
                self.group_size = 4

        class Scale:
            def __init__(self):
                self.proba = 1
                self.min = 0.25
                self.max = 1.25

    class Test:
        """Evaluation options."""

        def __init__(self):
            self.save = False
            self.best = True
            self.workers = 2
            self.every = 5
            self.split = True
            self.shifts = 1
            self.overlap = 0.25
            self.sdr = True
            self.metric = 'loss'
            self.nonhq = None

    class MiscArgs:
        """Miscellaneous options."""

        def __init__(self):
            # You can add any other miscellaneous arguments here if needed.
            self.show = False  # get_my_solver: print model sizes and exit when True
            self.num_workers = 10  # DataLoader workers in get_my_solver
            self.num_prints = 4
            self.verbose = False

# Build the configuration object from the defaults defined in `Args`
# (no config file is read).
args = Args()

# Nested sections are plain attributes, e.g. the MUSDB dataset path:
print(args.dset.musdb)

/home/ubuntu/odml_final/distillation_demucs/MUSDBHQ


In [21]:
import demucs.train
from kt_solver import KTSolver
from demucs.repitch import RepitchedWrapper


# NOTE(review): this loads the datasets eagerly; get_my_solver and a later cell load them
# again, so consider keeping only one of these calls.
train_set, valid_set = demucs.train.get_datasets(args)
# Same expression as `device` in the setup cell; get_my_solver reads this global.
device = "cuda" if torch.cuda.is_available() else "cpu"


def get_my_solver(args, model_only=False):
    """Build a ``KTSolver`` for distilling the teacher model into the student model.

    Args:
        args: configuration shaped like ``Args`` (``seed``, ``batch_size``, ``dset``,
            ``augment``, ``misc``, ...). NOTE: ``args.batch_size`` is divided in
            place by the distributed world size, so reusing the same object across
            calls with more than one worker shrinks it again.
        model_only: if True, skip dataset/loader construction and pass ``None`` as
            the loaders.

    Returns:
        ``KTSolver(loaders, student_model, teacher_model, optimizer, args)``.
        If ``args.misc.show`` is set, model sizes are printed and ``sys.exit(0)`` is
        called instead.
    """
    from collections.abc import Mapping

    distrib.init()
    torch.manual_seed(args.seed)
    # get_student_teacher_models returns (student, teacher), matching how it is unpacked
    # elsewhere in this notebook. Unpacking in the reverse order swapped the models,
    # so the optimizer ended up training the large teacher.
    student_model, teacher_model = kd_helper.get_student_teacher_models(partial_weight_copy=True)
    if args.misc.show:
        mb = sum(p.numel() for p in teacher_model.parameters()) * 4 / 2**20
        print(f"Teacher model has {mb:.1f}MB")
        smb = sum(p.numel() for p in student_model.parameters()) * 4 / 2**20
        print(f"Student model has {smb:.1f}MB")
        if hasattr(teacher_model, "valid_length"):
            field = teacher_model.valid_length(1)
            print(f"Field: {field/args.dset.samplerate*1000:.1f}ms")
        sys.exit(0)

    teacher_model.to(device)
    student_model.to(device)

    # Only the student's parameters are handed to the optimizer.
    optimizer = demucs.train.get_optimizer(student_model, args)

    assert args.batch_size % distrib.world_size == 0
    args.batch_size //= distrib.world_size

    if model_only:
        return KTSolver(None, student_model, teacher_model, optimizer, args)

    train_set, valid_set = demucs.train.get_datasets(args)

    repitch = args.augment.repitch
    if repitch.proba:
        vocals = []
        # Use the same list for the membership test and the index lookup.
        if 'vocals' in args.dset.sources:
            vocals.append(args.dset.sources.index('vocals'))
        else:
            logger.warning("No vocal source found")
        # Plain-attribute configs such as `Args` are not mappings, so `**repitch` raised
        # TypeError; accept mapping-style configs as well as attribute objects.
        repitch_kwargs = dict(repitch) if isinstance(repitch, Mapping) else vars(repitch)
        train_set = RepitchedWrapper(train_set, vocals=vocals, **repitch_kwargs)

    logger.info("train/valid set size: %d %d", len(train_set), len(valid_set))
    train_loader = distrib.loader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.misc.num_workers, drop_last=True)
    if args.dset.full_cv:
        valid_loader = distrib.loader(
            valid_set, batch_size=1, shuffle=False,
            num_workers=args.misc.num_workers)
    else:
        valid_loader = distrib.loader(
            valid_set, batch_size=args.batch_size, shuffle=False,
            num_workers=args.misc.num_workers, drop_last=True)
    loaders = {"train": train_loader, "valid": valid_loader}

    # Construct Solver
    return KTSolver(loaders, student_model, teacher_model, optimizer, args)
    

In [22]:
# Rebuild the global `train_set` / `valid_set` from the current `args`.
train_set, valid_set = demucs.train.get_datasets(args)

In [20]:
# Inspect one training example. The shape appears to be (sources, channels, samples):
# 4 sources, 2 channels, and 485100 = segment 11 * samplerate 44100.
# NOTE(review): execution count 20 means this ran before the two cells above it.
train_set[0].shape

torch.Size([4, 2, 485100])