## Результаты модели разделения источников (separation model) с LoRA-адаптером

In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the project's .py files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Disabled shell command that reinstalls numba in the current environment.
# NOTE(review): presumably a workaround for a numba version conflict — confirm
# it is still needed, otherwise delete this cell.
# ! pip uninstall numba -y && pip install numba

In [3]:
import os
import sys

# Absolute path of the project root, taken to be the parent of the current
# working directory.
root_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Register the project root on sys.path exactly once so that project modules
# can be imported from this notebook.
if sys.path.count(root_path) == 0:
    sys.path.append(root_path)

In [4]:
from model_loaders import load_ss_model
import weightwatcher as ww
from matplotlib import pyplot as plt
from pipeline import separate_audio
import torch
from utils import parse_yaml
from models.clap_encoder import CLAP_Encoder
import IPython.display as ipd
from models.audiosep_lora import AudioSepLora
import pandas as pd
from utils import plot_separation_result

2024-05-04 19:29:00.039128: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-04 19:29:00.039192: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-04 19:29:00.054894: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-04 19:29:00.091029: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
# Config and checkpoint locations, relative to the notebook's working directory.
SS_CONFIG_PATH = '../config/audiosep_base.yaml'
CLAP_CKPT_PATH = '../checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt'
AUDIOSEP_CKPT_PATH = '../checkpoint/audiosep_base_4M_steps.ckpt'

# Text queries handed to the separation pipeline; one separated track is
# produced per query for every input file.
classes = ["bass", "drums", "vocals", 'other musical instruments']

# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing on the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parsed separation-model configuration, reused by every model load below.
configs = parse_yaml(SS_CONFIG_PATH)

In [6]:
# Query encoder: build it from the CLAP checkpoint, switch it to eval mode and
# move it to the selected device.
query_encoder = CLAP_Encoder(pretrained_path=CLAP_CKPT_PATH)
query_encoder = query_encoder.eval()
query_encoder = query_encoder.to(device)

# Base separation model loaded from the AudioSep checkpoint, wired to the
# query encoder above, then put into eval mode on the same device.
base_model = load_ss_model(
    configs=configs,
    checkpoint_path=AUDIOSEP_CKPT_PATH,
    query_encoder=query_encoder,
)
base_model = base_model.eval()
base_model = base_model.to(device)

2024-05-04 19:29:04,961 - INFO - Loading HTSAT-base model config.
2024-05-04 19:29:07,204 - INFO - Loading pretrained HTSAT-base-roberta weights (../checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt).


In [10]:
# Checkpoint that is loaded into the AudioSepLora wrapper below (path is
# relative to the notebook's working directory).
checkpoint_path = '../checkpoints/final/musdb18/lora/final.ckpt'

# A second, independent encoder + base-model pair is built for the LoRA wrapper.
# NOTE(review): presumably so that merging the adapter does not modify
# `base_model` from the earlier cell — confirm.
query_encoder_for_lora = CLAP_Encoder(pretrained_path=CLAP_CKPT_PATH)
query_encoder_for_lora = query_encoder_for_lora.eval().to(device)

base_model_for_lora = load_ss_model(
    configs=configs,
    checkpoint_path=AUDIOSEP_CKPT_PATH,
    query_encoder=query_encoder_for_lora,
)
base_model_for_lora = base_model_for_lora.eval().to(device)

# Restore the wrapper from the checkpoint. Only inference is done here, so the
# training-only arguments (loss, mixer, LR schedule) are passed as None.
# strict=False lets the load proceed even if checkpoint keys and module keys
# do not match exactly.
lora_model = AudioSepLora.load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    strict=False,
    pretrained_audiosep_model=base_model_for_lora,
    loss_function=None,
    waveform_mixer=None,
    lr_lambda_func=None,
)
lora_model = lora_model.eval()
lora_model = lora_model.to(device)

# NOTE(review): merge_and_unload looks like the PEFT API that folds the adapter
# weights into the base weights and returns the plain model — confirm.
merged_lora_model = lora_model.model.merge_and_unload()
# Attach the wrapper's query encoder to the merged model.
# NOTE(review): presumably the merged model needs it for text queries — confirm.
merged_lora_model.query_encoder = lora_model.model.query_encoder

2024-05-04 19:32:54,639 - INFO - Loading HTSAT-base model config.
2024-05-04 19:32:55,997 - INFO - Loading pretrained HTSAT-base-roberta weights (../checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt).


In [23]:
# Imported here so this cell stays runnable on its own; ideally this line lives
# in the notebook's import cell.
from ipywidgets import widgets


def _audio_widget_from_file(path):
    """Return a non-autoplaying WAV player widget holding the bytes of `path`.

    The file is read inside a context manager so its handle is closed
    immediately instead of being left to the garbage collector.
    """
    with open(path, "rb") as audio_file:
        audio_bytes = audio_file.read()
    return widgets.Audio(value=audio_bytes, format="wav", controls=True, autoplay=False)


# Input files to separate. The listing is sorted because os.listdir returns
# entries in arbitrary order, and sub-directories are skipped because they
# cannot be opened as audio files.
test_files = [
    f'../audios/{filename}'
    for filename in sorted(os.listdir('../audios'))
    if os.path.isfile(f'../audios/{filename}')
]

# audio_widgets[column][filename] -> player widget, where column is either
# 'original' or one of the query classes.
audio_widgets = {cls: {} for cls in classes + ['original']}

# The output directory does not depend on the file or class, so create it once.
output_dir = '../separation_result/audiosep_lora/'
os.makedirs(output_dir, exist_ok=True)

for file in test_files:
    # basename handles both '/' and the platform separator, unlike splitting on
    # os.sep (the paths above are built with a literal '/').
    filename = os.path.basename(file)
    audio_widgets['original'][filename] = _audio_widget_from_file(file)

    for cls in classes:
        output_file = os.path.join(output_dir, f'{filename}_{cls}.wav')
        # Writes the separated track for query `cls` to output_file.
        separate_audio(merged_lora_model, file, cls, output_file, device, use_chunk=True)
        audio_widgets[cls][filename] = _audio_widget_from_file(output_file)

Separating audio from [../audios/ANIKV - слов больше нет - Lørean Edit.wav] with textual query: [bass]
Separated audio written to [../separation_result/audiosep_lora/ANIKV - слов больше нет - Lørean Edit.wav_bass.wav]
Separating audio from [../audios/ANIKV - слов больше нет - Lørean Edit.wav] with textual query: [drums]
Separated audio written to [../separation_result/audiosep_lora/ANIKV - слов больше нет - Lørean Edit.wav_drums.wav]
Separating audio from [../audios/ANIKV - слов больше нет - Lørean Edit.wav] with textual query: [vocals]
Separated audio written to [../separation_result/audiosep_lora/ANIKV - слов больше нет - Lørean Edit.wav_vocals.wav]
Separating audio from [../audios/ANIKV - слов больше нет - Lørean Edit.wav] with textual query: [other musical instruments]
Separated audio written to [../separation_result/audiosep_lora/ANIKV - слов больше нет - Lørean Edit.wav_other musical instruments.wav]
Separating audio from [../audios/CREAM SODA - Никаких Больше Вечеринок.wav] with

In [24]:
# Display the collected players (the 'original' column plus one column per
# query class, keyed by file name) using the project's plotting helper.
plot_separation_result(audio_widgets)

HTML(value='<style>.widget-label { font-size: 16px; font-weight: bold; }</style>')

VBox(children=(HBox(children=(Label(value=''), Label(value='bass', layout=Layout(margin='0 10px 0 10px')), Lab…