In [1]:
import json
%load_ext autoreload
%autoreload 2

In [2]:
# ! pip uninstall numba -y && pip install numba

In [3]:
import os
import sys

# Absolute path of the project root, i.e. the parent of the current working directory
root_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Make the project root importable; the membership check keeps re-runs of this
# cell from appending the same entry again
if root_path not in sys.path:
    sys.path.append(root_path)

In [4]:
from model_loaders import load_ss_model
import weightwatcher as ww
from matplotlib import pyplot as plt
from pipeline import separate_audio
import torch
from utils import parse_yaml
from models.clap_encoder import CLAP_Encoder
import IPython.display as ipd
from models.audiosep_lora_and_tuned_embeddings import AudioSepLoraAndTunedEmbeddings

2024-05-10 16:01:48.026308: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-10 16:01:48.026332: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-10 16:01:48.027145: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-10 16:01:48.031173: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
# Config and pretrained checkpoint locations, relative to the notebook's working directory
SS_CONFIG_PATH = '../config/audiosep_base.yaml'
CLAP_CKPT_PATH = '../checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt'
AUDIOSEP_CKPT_PATH = '../checkpoint/audiosep_base_4M_steps.ckpt'
# Prefer the GPU, but fall back to the CPU so that the later `.to(device)` calls
# do not fail on machines without CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
configs = parse_yaml(SS_CONFIG_PATH)

Подготавливаем несколько миксов

In [6]:
from pydub import AudioSegment
import json
import random

def mix_wav_files(files, output_path):
    """Overlay several audio files into one mixture and export it as WAV.

    Args:
        files: Non-empty sequence of audio file paths, each loaded with
            ``AudioSegment.from_file``.
        output_path: Destination path of the exported WAV mixture.

    Raises:
        ValueError: If ``files`` is empty.
    """
    if not files:
        raise ValueError('mix_wav_files() requires at least one input file')

    segments = [AudioSegment.from_file(path) for path in files]

    # overlay() returns a segment as long as the one it is called on, so a longer
    # clip laid over a shorter base would be cut off. Start from the longest clip
    # (the first one on ties) so that no input is truncated.
    base_index = max(range(len(segments)), key=lambda index: len(segments[index]))
    mixed_sound = segments[base_index]
    for index, segment in enumerate(segments):
        if index != base_index:
            mixed_sound = mixed_sound.overlay(segment)

    mixed_sound.export(output_path, format='wav')

def select_random_elements(data, n_elements, seed=None):
    """Return up to ``n_elements`` distinct items sampled at random from ``data``.

    Args:
        data: Sequence to sample from.
        n_elements: Requested sample size; capped at ``len(data)``.
        seed: Optional seed. When given, the draw is reproducible and is made
            with a private generator, so the global ``random`` state is left
            untouched. When ``None``, the global generator is used.

    Returns:
        A new list with the sampled items.
    """
    # random.Random(seed) yields the same sample as random.seed(seed) followed by
    # random.sample(), without the hidden side effect of reseeding the module RNG.
    rng = random if seed is None else random.Random(seed)
    sample_size = min(n_elements, len(data))
    return rng.sample(data, sample_size)


In [7]:
# CLAP query encoder for the baseline model, switched to eval mode and moved to `device`
query_encoder = CLAP_Encoder(pretrained_path=CLAP_CKPT_PATH)
query_encoder = query_encoder.eval().to(device)

# Baseline separation model loaded from the pretrained AudioSep checkpoint
base_model = load_ss_model(
    configs=configs,
    checkpoint_path=AUDIOSEP_CKPT_PATH,
    query_encoder=query_encoder,
)
base_model = base_model.eval().to(device)

2024-05-10 16:01:52,552 - INFO - Loading HTSAT-base model config.
2024-05-10 16:01:54,886 - INFO - Loading pretrained HTSAT-base-roberta weights (../checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt).


In [8]:
# Checkpoint of the fine-tuned AudioSepLoraAndTunedEmbeddings module
checkpoint_path = '../checkpoints/final/dota2/lora_embeddings/final.ckpt'

# A separate encoder / base-model pair is built for the LoRA wrapper instead of
# reusing `query_encoder` / `base_model` from the previous cell
query_encoder_for_lora = CLAP_Encoder(pretrained_path=CLAP_CKPT_PATH)
query_encoder_for_lora = query_encoder_for_lora.eval().to(device)

base_model_for_lora = load_ss_model(
    configs=configs,
    checkpoint_path=AUDIOSEP_CKPT_PATH,
    query_encoder=query_encoder_for_lora,
)
base_model_for_lora = base_model_for_lora.eval().to(device)

# Training-only arguments are passed as None; the module is only used for inference here
lora_model = AudioSepLoraAndTunedEmbeddings.load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    strict=False,
    pretrained_audiosep_model=base_model_for_lora,
    loss_function=None,
    waveform_mixer=None,
    lr_lambda_func=None,
)
lora_model = lora_model.eval().to(device)

# Merge the adapters into the wrapped model, then re-attach the query encoder
# taken from the wrapped model to the merged result
merged_lora_model = lora_model.model.merge_and_unload()
merged_lora_model.query_encoder = lora_model.model.query_encoder

2024-05-10 16:01:59,995 - INFO - Loading HTSAT-base model config.
2024-05-10 16:02:01,330 - INFO - Loading pretrained HTSAT-base-roberta weights (../checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt).


In [18]:
import shutil
from ipywidgets import widgets


def _read_audio_widget(path):
    """Return a ``widgets.Audio`` player holding the bytes of the WAV file at ``path``."""
    # Read through a context manager so the file handle is closed right away
    with open(path, 'rb') as audio_file:
        audio_bytes = audio_file.read()
    return widgets.Audio(value=audio_bytes, format="wav", controls=True, autoplay=False)


def separate_and_visualize(mix_cnt, seeds):
    """Build random mixtures from the validation set, separate them and show players.

    For every seed, ``mix_cnt`` entries are sampled from
    ``../datafiles/dota2.val.json``, their ``wav`` files are overlaid into
    ``dota2_mixtures/{mix_cnt}_{seed}.wav``, and ``separate_audio`` is run on the
    mixture once per distinct ``caption``, using the caption as the text query.
    Results go to ``../separation_result/dota2/{mix_cnt}_{seed}/{caption}.wav``,
    next to a copy of the mixture saved as ``original.wav``. Finally one widget
    table per seed is rendered with ``plot_separation_result``.

    Relies on the notebook-level ``merged_lora_model`` and ``device``.

    Args:
        mix_cnt: Number of validation clips overlaid into each mixture.
        seeds: Re-iterable collection of hashable seeds; one mixture is built per seed.
    """
    from utils import plot_separation_result

    with open('../datafiles/dota2.val.json') as datafile:
        parsed_json = json.load(datafile)['data']

    # Phase 1: pick the clips (and their captions) that make up each seed's mixture
    seed_mixtures = {}
    for seed in seeds:
        random_mixes = select_random_elements(parsed_json, mix_cnt, seed)
        seed_mixtures[seed] = {
            'test_files': [f'../{mix["wav"]}' for mix in random_mixes],
            'classes': [mix['caption'] for mix in random_mixes],
        }

    # Phase 2: write the mixtures to disk
    mixture_dir = 'dota2_mixtures'
    os.makedirs(mixture_dir, exist_ok=True)

    for seed in seeds:
        mixture_path = f'{mixture_dir}/{mix_cnt}_{seed}.wav'
        mix_wav_files(seed_mixtures[seed]['test_files'], mixture_path)
        seed_mixtures[seed]['mixture_path'] = mixture_path

    # Phase 3: separate every mixture once per caption and collect audio players
    audio_widgets_list = []
    for seed in seeds:
        mixture_path = seed_mixtures[seed]['mixture_path']
        # Two sampled clips may share a caption; separating twice would only
        # rewrite the same output file, so keep the first occurrence of each.
        classes = list(dict.fromkeys(seed_mixtures[seed]['classes']))
        audio_widgets = {cls: {} for cls in classes + ['original']}

        # basename/splitext do not depend on os.sep matching the '/' used above
        filename = os.path.basename(mixture_path)
        audio_widgets['original'][filename] = _read_audio_widget(mixture_path)

        output_dir = f'../separation_result/dota2/{os.path.splitext(filename)[0]}/'
        os.makedirs(os.path.dirname(output_dir), exist_ok=True)

        for cls in classes:
            output_file = os.path.join(output_dir, f'{cls}.wav')
            separate_audio(merged_lora_model, mixture_path, cls, output_file, device, use_chunk=False)
            audio_widgets[cls][filename] = _read_audio_widget(output_file)

        audio_widgets_list.append(audio_widgets)
        shutil.copyfile(mixture_path, os.path.join(output_dir, 'original.wav'))

    # Phase 4: render one table of players per mixture
    for audio_widgets in audio_widgets_list:
        plot_separation_result(audio_widgets)

# Попробуем разделить микс из реплик двух героев

In [19]:
# Two clips per mixture; every seed yields its own reproducible selection of clips
mix_cnt = 2
seeds = [678, 789, 909, 345]
separate_and_visualize(mix_cnt, seeds)

Separating audio from [dota2_mixtures/2_678.wav] with textual query: [pugna]
Separated audio written to [../separation_result/dota2/2_678/pugna.wav]
Separating audio from [dota2_mixtures/2_678.wav] with textual query: [beastmaster]
Separated audio written to [../separation_result/dota2/2_678/beastmaster.wav]
Separating audio from [dota2_mixtures/2_789.wav] with textual query: [enchantress]
Separated audio written to [../separation_result/dota2/2_789/enchantress.wav]
Separating audio from [dota2_mixtures/2_789.wav] with textual query: [medusa]
Separated audio written to [../separation_result/dota2/2_789/medusa.wav]
Separating audio from [dota2_mixtures/2_909.wav] with textual query: [pudge]
Separated audio written to [../separation_result/dota2/2_909/pudge.wav]
Separating audio from [dota2_mixtures/2_909.wav] with textual query: [invoker]
Separated audio written to [../separation_result/dota2/2_909/invoker.wav]
Separating audio from [dota2_mixtures/2_345.wav] with textual query: [wraith

HTML(value='<style>.widget-label { font-size: 16px; font-weight: bold; }</style>')

VBox(children=(HBox(children=(Label(value=''), Label(value='pugna', layout=Layout(margin='0 10px 0 10px')), La…

HTML(value='<style>.widget-label { font-size: 16px; font-weight: bold; }</style>')

VBox(children=(HBox(children=(Label(value=''), Label(value='enchantress', layout=Layout(margin='0 10px 0 10px'…

HTML(value='<style>.widget-label { font-size: 16px; font-weight: bold; }</style>')

VBox(children=(HBox(children=(Label(value=''), Label(value='pudge', layout=Layout(margin='0 10px 0 10px')), La…

HTML(value='<style>.widget-label { font-size: 16px; font-weight: bold; }</style>')

VBox(children=(HBox(children=(Label(value=''), Label(value='wraith king', layout=Layout(margin='0 10px 0 10px'…

Качество выделения неплохое, хотя в каждой из аудиозаписей есть призвуки другого персонажа.

# Теперь попробуем разделить микс из реплик трех героев — эта задача сложнее

In [20]:
# Three clips per mixture, with the same seeds as in the two-clip run
mix_cnt = 3
seeds = [678, 789, 909, 345]
separate_and_visualize(mix_cnt, seeds)

Separating audio from [dota2_mixtures/3_678.wav] with textual query: [pugna]
Separated audio written to [../separation_result/dota2/3_678/pugna.wav]
Separating audio from [dota2_mixtures/3_678.wav] with textual query: [beastmaster]
Separated audio written to [../separation_result/dota2/3_678/beastmaster.wav]
Separating audio from [dota2_mixtures/3_678.wav] with textual query: [kunkka]
Separated audio written to [../separation_result/dota2/3_678/kunkka.wav]
Separating audio from [dota2_mixtures/3_789.wav] with textual query: [enchantress]
Separated audio written to [../separation_result/dota2/3_789/enchantress.wav]
Separating audio from [dota2_mixtures/3_789.wav] with textual query: [medusa]
Separated audio written to [../separation_result/dota2/3_789/medusa.wav]
Separating audio from [dota2_mixtures/3_789.wav] with textual query: [razor]
Separated audio written to [../separation_result/dota2/3_789/razor.wav]
Separating audio from [dota2_mixtures/3_909.wav] with textual query: [pudge]
S

HTML(value='<style>.widget-label { font-size: 16px; font-weight: bold; }</style>')

VBox(children=(HBox(children=(Label(value=''), Label(value='pugna', layout=Layout(margin='0 10px 0 10px')), La…

HTML(value='<style>.widget-label { font-size: 16px; font-weight: bold; }</style>')

VBox(children=(HBox(children=(Label(value=''), Label(value='enchantress', layout=Layout(margin='0 10px 0 10px'…

HTML(value='<style>.widget-label { font-size: 16px; font-weight: bold; }</style>')

VBox(children=(HBox(children=(Label(value=''), Label(value='pudge', layout=Layout(margin='0 10px 0 10px')), La…

HTML(value='<style>.widget-label { font-size: 16px; font-weight: bold; }</style>')

VBox(children=(HBox(children=(Label(value=''), Label(value='wraith king', layout=Layout(margin='0 10px 0 10px'…

На миксах из трех реплик качество намного хуже, чем на миксах из двух, из-за большего количества шума. Особенно это заметно на миксах, где голоса героев очень близки по тембру.

Судя по графикам обучения, после 50 эпох метрики все еще продолжали расти. Вероятно, если обучить модель на большем количестве эпох с другим планировщиком learning rate (например, cosine), качество выделения станет лучше.