In [None]:
# Automatically re-import edited project modules (model_loaders, pipeline, models, ...)
# before each cell executes, so changes to their source are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Optional environment fix, intentionally left commented out: reinstall numba.
# NOTE(review): presumably works around a broken numba install in some environment —
# TODO document why; if re-enabled, prefer `%pip` (targets the kernel's env) and pin a version.
# ! pip uninstall numba -y && pip install numba

In [None]:
import os
import sys

# Absolute path of the project root: the parent of the current working directory.
# The notebook presumably runs from a subdirectory of the project, which is also what
# the '../config' and '../checkpoint' relative paths used in later cells assume.
root_path = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Prepend (not append) the project root so its generically named local modules
# (utils, pipeline, models, model_loaders) take precedence over any installed
# packages that share those names; appending would let site-packages shadow them.
if root_path not in sys.path:
    sys.path.insert(0, root_path)

In [None]:
import IPython.display as ipd
import numpy as np
import pandas as pd
import torch
import weightwatcher as ww
from matplotlib import pyplot as plt

from model_loaders import load_ss_model
from models.audiosep_lora_and_tuned_embeddings import AudioSepLoraAndTunedEmbeddings
from models.clap_encoder import CLAP_Encoder
from pipeline import separate_audio
from utils import parse_yaml

In [None]:
# Paths are relative to the notebook's working directory.
# NOTE(review): pretrained weights live under ../checkpoint/ while the fine-tuned runs
# loaded later use ../checkpoints/ — confirm both directories are intended.
SS_CONFIG_PATH = '../config/audiosep_base.yaml'
CLAP_CKPT_PATH = '../checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt'
AUDIOSEP_CKPT_PATH = '../checkpoint/audiosep_base_4M_steps.ckpt'

# Text queries for separation; one output stem is produced per class.
classes = ["bass", "drums", "vocals"]

# Use the GPU when present, otherwise fall back to CPU so the notebook still runs
# (slowly) on CPU-only machines. On a CUDA machine this is identical to 'cuda'.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

configs = parse_yaml(SS_CONFIG_PATH)

In [None]:
# Reference (non-LoRA) model: pretrained CLAP query encoder + pretrained AudioSep
# separator, both put in eval mode and moved to `device`.
query_encoder = CLAP_Encoder(pretrained_path=CLAP_CKPT_PATH).eval().to(device)
base_model = (
    load_ss_model(
        configs=configs,
        checkpoint_path=AUDIOSEP_CKPT_PATH,
        query_encoder=query_encoder,
    )
    .eval()
    .to(device)
)

In [None]:
# Fine-tuned run to inspect (LoRA adapters + tuned query embeddings).
checkpoint_path = '../checkpoints/train_audiosep_lora_and_tuned_embeddings/audiosep_lora_and_tuned_embeddings_musdb18,args=logs_per_class=True, dropout=0.1,timestamp=1712871001.7698753/epoch=19.ckpt'

# A second, independent copy of the pretrained encoder + separator, so the LoRA model
# wraps its own weights rather than `base_model`.
# NOTE(review): merge_and_unload() below presumably folds the LoRA deltas into these
# weights in place (PEFT semantics), which is why `base_model` must remain a separate
# copy for the weight comparisons later in the notebook — confirm.
query_encoder_for_lora = CLAP_Encoder(pretrained_path=CLAP_CKPT_PATH).eval().to(device)
base_model_for_lora = (
    load_ss_model(
        configs=configs,
        checkpoint_path=AUDIOSEP_CKPT_PATH,
        query_encoder=query_encoder_for_lora,
    )
    .eval()
    .to(device)
)

# strict=False: checkpoint/model key mismatches are tolerated instead of raising.
# The loss / mixer / LR-schedule arguments are passed as None — presumably they are
# only needed for training, not for inference.
lora_model = (
    AudioSepLoraAndTunedEmbeddings.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        strict=False,
        pretrained_audiosep_model=base_model_for_lora,
        loss_function=None,
        waveform_mixer=None,
        lr_lambda_func=None,
    )
    .eval()
    .to(device)
)

# Merged model (exposes `.ss_model`, used by the separation and weight-analysis cells).
merged_lora_model = lora_model.model.merge_and_unload()

In [None]:
# Separate each test mixture into one stem per class; play the input and every stem.
test_files = ['../evaluation/data/musdb18/test/Lyndsey Ollard - Catching Up/mixture.wav']

# Make sure the directory that `output_file` points into exists before writing to it.
os.makedirs('../separation_result', exist_ok=True)

for file in test_files:
    display(ipd.Audio(file))
    # os.path.basename understands '/' on every OS; the previous split on os.sep left the
    # whole path unsplit on Windows (os.sep == '\\') because these paths use '/'.
    filename = os.path.basename(file)
    for cls in classes:
        # NOTE(review): every MUSDB18 track file is named mixture.wav, so with more than
        # one entry in test_files these output paths collide and later tracks overwrite
        # earlier ones — consider including the track directory name.
        output_file = f'../separation_result/audiosep_lora_and_tuned_embedings_{filename}_{cls}.wav'
        separate_audio(merged_lora_model, file, cls, output_file, device, use_chunk=True)
        display(ipd.Audio(output_file))

In [None]:
# Wide, short figures for the per-layer weight histograms drawn by plot_hist.
plt.rcParams.update({"figure.figsize": (20, 3)})

In [None]:
# Per-layer weight statistics and histogram for the pretrained (non-LoRA) separator.
# FIXME(review): `describe_weights` and `plot_hist` are neither defined nor imported in
# any visible cell, so on a fresh kernel this would raise NameError — TODO import them
# from the module that provides them.
base_details, base_summary = describe_weights(base_model.ss_model)
plot_hist(base_details)
print(base_summary)

In [None]:
# Same statistics for the merged LoRA separator, for side-by-side comparison with the
# pretrained model.
# FIXME(review): `describe_weights` and `plot_hist` are neither defined nor imported in
# any visible cell — TODO import them.
lora_details, lora_summary = describe_weights(merged_lora_model.ss_model)
plot_hist(lora_details)
print(lora_summary)

In [None]:
# Layer-by-layer distance between pretrained and merged-LoRA separator weights.
# `distances` is the per-layer table (filtered and aggregated in later cells);
# the two averages are shown as this cell's output.
watcher = ww.WeightWatcher()
avg_dW, avg_db, distances = watcher.distances(
    base_model.ss_model,
    merged_lora_model.ss_model,
)
(avg_dW, avg_db)

In [None]:
# Full per-layer distance table.
distances

In [None]:
# Module structure of the merged model (useful for mapping `longname`s to modules).
print(merged_lora_model)

In [None]:
# Keep only the rows of the distance table whose layer type is Conv2d.
conv_distances = distances.loc[distances['name'].eq('Conv2d')]

In [None]:
# Split each Conv2d layer's dotted module path ("a.b.c") into one column per component.
df_split = conv_distances['longname'].str.split('.', expand=True)

# Attach each layer's weight change (delta_W) to its path components (index-aligned).
df_expanded = pd.concat([df_split, conv_distances['delta_W']], axis=1)

# Report summed delta_W for progressively longer module-path prefixes: level 1 groups by
# the first component, level 2 by the first two, and so on. Only delta_W is aggregated;
# summing the remaining string path columns would concatenate module names into
# meaningless values. Layers whose path is shorter than the current prefix have a missing
# key there and are dropped from that level (groupby's default dropna=True).
n_path_parts = df_split.shape[1]
for level in range(n_path_parts):
    prefix_cols = list(range(level + 1))
    grouped_df = (
        df_expanded
        .groupby(prefix_cols)['delta_W']
        .sum()
        .reset_index()
        .sort_values(by='delta_W', ascending=False)
    )
    print(f"Группировка по уровню {level + 1}:\n", grouped_df, "\n")


In [None]:
# Pretrained-model layers for which describe_weights reported a non-empty warning.
base_details[base_details['warning']!= '']

In [None]:
# Merged-LoRA layers for which describe_weights reported a non-empty warning.
lora_details[lora_details['warning']!= '']

In [None]:
# Compare Conv2d weight norms between the pretrained and merged-LoRA separators.
# The two layer iterators are walked in lockstep; the longname check below guards the
# assumption that both models enumerate the same layers in the same order.
watcher = ww.WeightWatcher()
layer1_iterator = watcher.make_layer_iterator(model=base_model.ss_model)
layer2_iterator = watcher.make_layer_iterator(model=merged_lora_model.ss_model)

# Collect one record per layer and build the DataFrame once at the end:
# DataFrame.append was removed in pandas 2.0, and growing a frame row by row is quadratic.
records = []
for layer1, layer2 in zip(layer1_iterator, layer2_iterator):
    if layer1.name != 'Conv2d':
        continue
    if layer1.longname != layer2.longname:
        raise ValueError(
            f'layer names are not equal! {layer1.longname!r} vs {layer2.longname!r}'
        )

    has_weights1, W1, _, _ = layer1.get_weights_and_biases()
    has_weights2, W2, _, _ = layer2.get_weights_and_biases()
    # Check availability before touching W1/W2; they may not be arrays otherwise.
    if not (has_weights1 and has_weights2):
        continue

    # Compute in float32 (presumably the stored weights may be lower precision).
    W1 = W1.astype(np.float32)
    W2 = W2.astype(np.float32)
    norm1 = np.linalg.norm(W1)
    norm2 = np.linalg.norm(W2)
    records.append({
        'layer_name': layer1.longname,
        'base_norm': norm1,
        'lora_norm': norm2,
        'norm_of_diff': watcher.matrix_distance(W1, W2),
        'diff_of_norms': norm1 - norm2,
    })

# Explicit columns keep the same schema even when no Conv2d layer matched.
metrics_df = pd.DataFrame(
    records,
    columns=['layer_name', 'base_norm', 'lora_norm', 'norm_of_diff', 'diff_of_norms'],
)

metrics_df

In [None]:
# Raw checkpoint of a fine-tuning run (has a 'state_dict' entry, presumably a Lightning
# checkpoint). NOTE(review): this is a different run (timestamp/epoch) than
# `checkpoint_path` used for lora_model — confirm this is intended.
# map_location='cpu' lets it load without a GPU; the tensors are only inspected here.
# torch.load unpickles arbitrary objects — only use it on checkpoints you trust.
www = torch.load(
    '../checkpoints/train_audiosep_lora_and_tuned_embeddings/audiosep_lora_and_tuned_embeddings_musdb18,timestamp=1710451587.1702235/epoch=29.ckpt',
    map_location='cpu',
)

# Print every state_dict tensor whose key belongs to the tuned embedding layer.
print([v for k, v in www['state_dict'].items() if 'tuned_embedding_layer' in k.lower()])

In [None]:
# Weight statistics for the raw checkpoint loaded into `www`.
# NOTE(review): the other describe_weights calls receive a model's `.ss_model`, but this
# one receives the loaded checkpoint object (a dict with a 'state_dict' key) — confirm
# describe_weights supports that input.
# FIXME(review): `describe_weights` and `plot_hist` are not defined/imported in any visible cell.
test_details, test_summary = describe_weights(www)
plot_hist(test_details)
print(test_summary)
test_details