### Импорты

In [1]:
!pip install musdb -q
# !pip uninstall -y ffmpeg
# !pip uninstall -y ffmpeg-python]\[''']
# !pip install -q ffmpeg-python

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m963.2/963.2 kB[0m [31m19.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import musdb

In [3]:
# https://zenodo.org/records/1117372 - датасет

In [4]:
import torch
import torch.nn as nn

import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

from torch.utils.data import Dataset, DataLoader

import musdb

from tqdm import tqdm

In [5]:
import stempeg
import os

In [6]:
import random
random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7875c4149f90>

In [7]:
from sklearn.model_selection import train_test_split

### Необходимые классы

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [9]:
# https://github.com/sigsep/sigsep-mus-db - примеры работы с musdb

In [10]:
class MusDBDataset(Dataset):
  def __init__(self, mus_dataset
               , current_sample_rate
               , target_sample_rate
               , num_input_seconds
               , num_output_seconds
               , one_bit_seconds
               , device = device):
    super().__init__()

    self.mus_dataset = mus_dataset
    self.current_sample_rate = current_sample_rate
    self.target_sample_rate = target_sample_rate
    self.num_input_seconds = num_input_seconds
    self.num_output_seconds = num_output_seconds
    self.one_bit_seconds = one_bit_seconds
    self.num_bits_in_second = 1 / self.one_bit_seconds # избегаем ошибку округления
    self.device = device

    self.resampler = T.Resample(self.current_sample_rate, self.target_sample_rate)

    self.durations = list(map(lambda track: track.duration, self.mus_dataset))
    self.num_possible_samples_from_track = list(
        map(
            lambda duration: int((duration - self.num_input_seconds \
                              - self.num_output_seconds) \
            * self.num_bits_in_second) + 1,
            self.durations)
        )
    self.start_seconds = []
    for idx, num_samples in enumerate(self.num_possible_samples_from_track):
      self.start_seconds.extend([(
          idx, i/self.num_bits_in_second
          ) for i in range(int(num_samples))])

    self.cached_tracks = {}
    self.cached_parts = {}

    self.expected_shape_X = int(self.num_input_seconds*self.target_sample_rate)
    self.expected_shape_y = int(self.num_output_seconds*self.target_sample_rate)


  def _change_sample_rate(self, array, current_sample_rate, target_sample_rate):
    array = torch.tensor(array, dtype=torch.float32)

    if current_sample_rate!=target_sample_rate:
      array = self.resampler(array)

    return array.to(device)

  def __len__(self):
    return sum(self.num_possible_samples_from_track)

  def _pad(self, item, expectation):
    return torch.nn.functional.pad(item, (0, expectation-item.shape[0]))

  def _prepare_track(self, track_num, start):

    if self.cached_tracks.get(track_num, 'no') == 'no':

      track = self.mus_dataset[track_num]



      vocals = self._change_sample_rate(track.targets['vocals'].audio.mean(axis=-1)
                                          , self.current_sample_rate
                                          , self.target_sample_rate)
      drums = self._change_sample_rate(track.targets['drums'].audio.mean(axis=-1)
                                      , self.current_sample_rate
                                      , self.target_sample_rate)
      bass = self._change_sample_rate(track.targets['bass'].audio.mean(axis=-1)
                                      , self.current_sample_rate
                                      , self.target_sample_rate)
      other = self._change_sample_rate(track.targets['other'].audio.mean(axis=-1)
                                      , self.current_sample_rate
                                      , self.target_sample_rate)
      all = self._change_sample_rate(track.audio.mean(axis=-1)
                                      , self.current_sample_rate
                                      , self.target_sample_rate)

      self.cached_tracks[track_num] = {'vocals':vocals
                                       , 'drums':drums
                                       , 'bass':bass
                                       , 'other':other
                                       , 'all':all}

    track = self.cached_tracks[track_num]
    vocals = track['vocals']
    drums = track['drums']
    bass = track['bass']
    other = track['other']
    all = track['all']

    end_X = start + self.num_input_seconds
    end_y = end_X + self.num_output_seconds

    start_X = start * self.target_sample_rate
    start_X = int(start_X)

    end_X *= self.target_sample_rate
    end_X = int(end_X)

    end_y *= self.target_sample_rate
    end_y = int(end_y)

    vocals_X = self._pad(vocals[start_X:end_X], self.expected_shape_X)
    vocals_y = self._pad(vocals[end_X:end_y], self.expected_shape_y)



    drums_X = self._pad(drums[start_X:end_X], self.expected_shape_X)
    drums_y = self._pad(drums[end_X:end_y], self.expected_shape_y)



    bass_X = self._pad(bass[start_X:end_X], self.expected_shape_X)
    bass_y = self._pad(bass[end_X:end_y], self.expected_shape_y)



    other_X = self._pad(other[start_X:end_X], self.expected_shape_X)
    other_y = self._pad(other[end_X:end_y], self.expected_shape_y)



    all_X = self._pad(other[start_X:end_X], self.expected_shape_X)
    all_y = self._pad(other[end_X:end_y], self.expected_shape_y)



    return {'vocals':(vocals_X[:self.expected_shape_X]
                      , vocals_y[:self.expected_shape_y]),
                            'drums':(drums_X[:self.expected_shape_X]
                                     , drums_y)[:self.expected_shape_y],
                            'bass':(bass_X[:self.expected_shape_X]
                                    , bass_y[:self.expected_shape_y]),
                            'other':(other[:self.expected_shape_X]
                                     , other_y)[:self.expected_shape_y],
                            'all':(all_X[:self.expected_shape_X]
                                   , all_y[:self.expected_shape_y])}



  def __getitem__(self, idx):

    # return self.final_list[idx]
    # if self.cached_parts.get(idx, 'no') == 'no':
    track_num, start_timestep = self.start_seconds[idx]
    preprocessed_data = self._prepare_track(track_num, start_timestep)
    # self.cached_parts[idx] = preprocessed_data

    return preprocessed_data


# Мок модели

In [11]:
class MockNet(nn.Module):
  def __init__(self, input_size, output_size):
    super().__init__()
    self.input_size = input_size
    self.output_size = output_size
    self.linear = nn.Linear(self.input_size, self.output_size)
    self.activ = nn.Sigmoid()

  def forward(self, X):
    return self.activ(self.linear(X))

### Конфиг


In [12]:
config = {'current_sample_rate':22050,
'target_sample_rate': 22050,
'num_input_seconds': 5,
'num_output_seconds': 0.5,
'one_bit_seconds': 0.2,
'n_fft': 1024,
'batch_size': 8,
'learning_rate': 5e-5,
          'mel_bins': 128,
    'dropout_rate': 0.2,
    'use_bidirectional': False  # Recommended True for audio sequences
          , 'test_size': 0.2
          , 'num_epochs': 10
          , 'spectral_backward': False
          , 'use_7s': True
          , 'target_log': True
          }

In [13]:
config['hop_length']= config['n_fft']//2
config['sequence_length_input'] = int(1 + (config['num_input_seconds']*config['target_sample_rate'] - config['n_fft']) // config['hop_length'] + 2)
config['sequence_length_output'] = int(1 + (config['num_output_seconds']*config['target_sample_rate'] - config['n_fft']) // config['hop_length'] + 2)

In [14]:
import numpy as np
from dataclasses import dataclass, field

In [15]:
@dataclass
class Stem:
    audio: np.ndarray = None

@dataclass
class Track:
    audio: np.ndarray = None
    targets: dict = field(default_factory=lambda: {
'vocals': Stem(),
'drums': Stem(),
'bass': Stem(),
'other': Stem()
})

In [16]:
# raise ValueError

In [17]:
import joblib

In [18]:
if config['use_7s']:
  mus = musdb.DB(download=True, sample_rate=config['target_sample_rate'])

else:
  !gdown 1_6ivqpFCdflkd7JMyyx6_PMT1IaThEPu
  !mkdir ./musdb18
  !tar -xvzf ./musdb18.tar.gz -C ./
  !mkdir ./musdb18/train_extracted
  !mkdir ./musdb18/test_extracted
  stems_map = {0: 'mixture', 1: 'drums', 2: 'bass', 3: 'other', 4: 'vocals'}
  mus = []
  train_folder = os.listdir('./musdb18/train')
  test_folder = os.listdir('./musdb18/test')
  for filename in tqdm(train_folder):
    try:
      if not filename.startswith('_'):
        track = Track()

        for i in range(5):
          audio, _ = stempeg.read_stems('./musdb18/train/'+filename,
                                        stem_id=i,
                      sample_rate=config['target_sample_rate'],
                      ffmpeg_format="s16le"
                  )

          audio = audio.astype('float16')
          if audio.shape[-1] == 2:
            audio = audio.mean(-1).reshape(-1, 1)
          if i == 0:
            track.audio = audio
          else:
            track.targets[stems_map[i]].audio = audio
          track.duration = audio.shape[0] / config['target_sample_rate']

        vocals = track.targets['vocals'].audio.mean(axis=-1)
        drums = track.targets['drums'].audio.mean(axis=-1)
        bass = track.targets['bass'].audio.mean(axis=-1)
        other = track.targets['other'].audio.mean(axis=-1)
        all = track.audio.mean(axis=-1)
        with open('./musdb18/train_extracted/'+filename+'.'+str(round(track.duration, 2))+'.pkl', 'wb') as f:
          joblib.dump({'vocals':vocals
                                        , 'drums':drums
                                        , 'bass':bass
                                        , 'other':other
                                        , 'all':all}, f)

    except Exception as E:
      # print(E)
      continue




  mus_test = []
  for filename in tqdm(test_folder):
    try:
      if not filename.startswith('._'):
        track = Track()

        for i in range(5):
          audio, _ = stempeg.read_stems('./musdb18/test/'+filename,
                                        stem_id=i,
                      sample_rate=config['target_sample_rate'],
                      ffmpeg_format="s16le"
                  )

          audio = audio.astype('float16')
          if audio.shape[-1] == 2:
            audio = audio.mean(-1)
          if i == 0:
            track.audio = audio
          else:
            track.targets[stems_map[i]].audio = audio
          track.duration = audio.shape[0] / config['target_sample_rate']
        vocals = track.targets['vocals'].audio.mean(axis=-1)
        drums = track.targets['drums'].audio.mean(axis=-1)
        bass = track.targets['bass'].audio.mean(axis=-1)
        other = track.targets['other'].audio.mean(axis=-1)
        all = track.audio.mean(axis=-1)
        with open('./musdb18/test_extracted/'+filename+'.'+str(round(track.duration, 2))+'.pkl', 'wb') as f:
          joblib.dump({'vocals':vocals
                                        , 'drums':drums
                                        , 'bass':bass
                                        , 'other':other
                                        , 'all':all}, f)
    except:
      continue

Downloading MUSDB 7s Sample Dataset to /root/MUSDB18/MUSDB18-7...


100%|██████████| 140M/140M [00:01<00:00, 118MB/s]


In [19]:
# for filename in tqdm(test_folder):

#   if not filename.startswith('._'):
#     track = Track()

#     for i in range(5):
#       audio, _ = stempeg.read_stems('./musdb18/test/'+filename,
#                                     stem_id=i,
#                   sample_rate=config['target_sample_rate'],
#                   ffmpeg_format="s16le"
#               )

#       audio = audio.astype('float16')
#       if audio.shape[-1] == 2:
#         audio = audio.mean(-1)
#       if i == 0:
#         track.audio = audio
#       else:
#         track.targets[stems_map[i]].audio = audio
#       track.duration = audio.shape[0] / config['target_sample_rate']
#     vocals = track.targets['vocals'].audio.mean(axis=-1)
#     drums = track.targets['drums'].audio.mean(axis=-1)
#     bass = track.targets['bass'].audio.mean(axis=-1)
#     other = track.targets['other'].audio.mean(axis=-1)
#     all = track.audio.mean(axis=-1)
#     with open('./musdb18/test_extracted/'+filename+'.'+str(round(track.duration, 2))+'.pkl', 'wb') as f:
#       joblib.dump({'vocals':vocals
#                                     , 'drums':drums
#                                     , 'bass':bass
#                                     , 'other':other
#                                     , 'all':all}, f)

In [20]:
import gc

In [21]:
class MusDBDatasetCached(Dataset):
  def __init__(self, mus_dataset_path
               , current_sample_rate
               , target_sample_rate
               , num_input_seconds
               , num_output_seconds
               , one_bit_seconds
               , device = device):
    super().__init__()

    self.mus_dataset_path = mus_dataset_path
    self.mus_dataset_filenames = os.listdir(mus_dataset_path)
    self.current_sample_rate = current_sample_rate
    self.target_sample_rate = target_sample_rate
    self.num_input_seconds = num_input_seconds
    self.num_output_seconds = num_output_seconds
    self.one_bit_seconds = one_bit_seconds
    self.num_bits_in_second = 1 / self.one_bit_seconds # избегаем ошибку округления
    self.device = device

    self.durations = list(map(lambda filename: float(filename.split('.')[-3]+'.'+filename.split('.')[-2]), self.mus_dataset_filenames))
    self.num_possible_samples_from_track = list(
        map(
            lambda duration: int((duration - self.num_input_seconds \
                              - self.num_output_seconds) \
            * self.num_bits_in_second) + 1,
            self.durations)
        )
    self.start_seconds = []
    for idx, num_samples in enumerate(self.num_possible_samples_from_track):
      self.start_seconds.extend([(
          idx, i/self.num_bits_in_second
          ) for i in range(int(num_samples))])

    self.cached_parts = {}

    self.expected_shape_X = int(self.num_input_seconds*self.target_sample_rate)
    self.expected_shape_y = int(self.num_output_seconds*self.target_sample_rate)


  def __len__(self):
    return sum(self.num_possible_samples_from_track)

  def _pad(self, item, expectation):
    return torch.nn.functional.pad(item, (0, expectation-item.shape[0]))

  def _prepare_track(self, track_num, start):

    with open(os.path.join(self.mus_dataset_path
                           , self.mus_dataset_filenames[track_num]), 'rb') as f:
      track = joblib.load(f)
    vocals = torch.tensor(track['vocals'], device=self.device)
    drums = torch.tensor(track['drums'], device=self.device)
    bass = torch.tensor(track['bass'], device=self.device)
    other = torch.tensor(track['other'], device=self.device)
    all = torch.tensor(track['all'], device=self.device)

    end_X = start + self.num_input_seconds
    end_y = end_X + self.num_output_seconds

    start_X = start * self.target_sample_rate
    start_X = int(start_X)

    end_X *= self.target_sample_rate
    end_X = int(end_X)

    end_y *= self.target_sample_rate
    end_y = int(end_y)

    vocals_X = self._pad(vocals[start_X:end_X], self.expected_shape_X)
    vocals_y = self._pad(vocals[end_X:end_y], self.expected_shape_y)


    drums_X = self._pad(drums[start_X:end_X], self.expected_shape_X)
    drums_y = self._pad(drums[end_X:end_y], self.expected_shape_y)


    bass_X = self._pad(bass[start_X:end_X], self.expected_shape_X)
    bass_y = self._pad(bass[end_X:end_y], self.expected_shape_y)



    other_X = self._pad(other[start_X:end_X], self.expected_shape_X)
    other_y = self._pad(other[end_X:end_y], self.expected_shape_y)



    all_X = self._pad(other[start_X:end_X], self.expected_shape_X)
    all_y = self._pad(other[end_X:end_y], self.expected_shape_y)



    return {'vocals':(vocals_X[:self.expected_shape_X]
                      , vocals_y[:self.expected_shape_y]),
                            'drums':(drums_X[:self.expected_shape_X]
                                     , drums_y)[:self.expected_shape_y],
                            'bass':(bass_X[:self.expected_shape_X]
                                    , bass_y[:self.expected_shape_y]),
                            'other':(other[:self.expected_shape_X]
                                     , other_y)[:self.expected_shape_y],
                            'all':(all_X[:self.expected_shape_X]
                                   , all_y[:self.expected_shape_y])}



  def __getitem__(self, idx):

    # return self.final_list[idx]
    # if self.cached_parts.get(idx, 'no') == 'no':
    track_num, start_timestep = self.start_seconds[idx]
    preprocessed_data = self._prepare_track(track_num, start_timestep)
    # self.cached_parts[idx] = preprocessed_data

    return preprocessed_data


In [22]:
if config['use_7s']:
  mus_train, mus_test = train_test_split(mus, test_size=config['test_size'], random_state=42)

  musdb_train = MusDBDataset(mus_train
                      , current_sample_rate=config['target_sample_rate']
                      , target_sample_rate=config['target_sample_rate']
                      , num_input_seconds=config['num_input_seconds']
                      , num_output_seconds=config['num_output_seconds']
                      , one_bit_seconds=config['one_bit_seconds']
                      , device=device)
  musdb_test = MusDBDataset(mus_test
                      , current_sample_rate=config['target_sample_rate']
                      , target_sample_rate=config['target_sample_rate']
                      , num_input_seconds=config['num_input_seconds']
                      , num_output_seconds=config['num_output_seconds']
                      , one_bit_seconds=config['one_bit_seconds']
                      , device=device)

else:

  musdb_train = MusDBDatasetCached('./musdb18/train_extracted'
  , current_sample_rate=config['target_sample_rate']
                      , target_sample_rate=config['target_sample_rate']
                      , num_input_seconds=config['num_input_seconds']
                      , num_output_seconds=config['num_output_seconds']
                      , one_bit_seconds=config['one_bit_seconds']
                      , device=device)
  musdb_test = MusDBDatasetCached('./musdb18/test_extracted'
  , current_sample_rate=config['target_sample_rate']
                      , target_sample_rate=config['target_sample_rate']
                      , num_input_seconds=config['num_input_seconds']
                      , num_output_seconds=config['num_output_seconds']
                      , one_bit_seconds=config['one_bit_seconds']
                      , device=device)

In [23]:
mus_dl_train = DataLoader(musdb_train, batch_size=config['batch_size'], shuffle=True)
mus_dl_test = DataLoader(musdb_test, batch_size=config['batch_size'], shuffle=True)

In [24]:
del mus
gc.collect()


0

# Метрики

Одной хорошей нет, нужен набор.

- Типы метрик по объекту оценки: Fidelity (качество аудиодорожки), Musicality (музыкальные характеристики аудиодорожки).
- Типы метрик по необходимости аудио: для мелспектрограмм, для аудио
- Типы метрик по оценке распределений: достаточно одного трека, необходим набор треков
- Как оценивать: автоматически, с использованием человеческой оценки

Ниже описываются только автоматические метрики для Fidelity.

Fidelity: \
1/ Frechet Audio Distance - для аудио, для распределения. Считается на эмбеддингах аудио (эвклидово расстояние центроидов эмбеддингов + след матрицы ковариаций). Есть нюансы с тем, какую модель эмбеддингов использовать: самый простой вариант - VGGish, наибольшую корреляцию с человеческой оценкой показывает на CLAP (https://arxiv.org/pdf/2506.19085). Есть расширение FAD-infinity для коррекции на размер батча. \
2/ CosineSim для эмбеддингов - аудио, отдельные треки/дорожки. Берём текстовые эмбеддинги CLAP для сгенерированного и эталонного аудио, считаем расстояние.\
3/ Reconstruction Loss - мелки/аудио, треки/дорожки. Прямолинейно. Считаем MSE на значения сгенерированной и эталонной спектрограмм. \
4/ Vendi Score (https://openreview.net/pdf?id=g97OHbQyk1) - аудио, распределение треков. Энтропия собственных значений матрицы схожести  эмбеддингов сгенерированной музыки. Метрика разнообразия. \
Musicality :\
Для мелок используются спектральные статистики. Центроид - оценка, какие частоты преобладают (высокие или низкие). Спектральная ширина - ширина спектра вокруг центроида (чем шире, тем сложнее тембр). Спектральная плоскость - Показывает, насколько спектр похож на шум (плоский) или на тон (пикообразный). Спектральная энтропия - насколько равномерно распределен сигнал по частотам (шум или не шум).


Также есть ритмические и тембральные характеристики, но они считаются на аудио (не на мелспектрограммах) и спектральные статистики их могут частично компенсировать.


Подходы с обучением дополнительных моделей:
- по мелспектрограммам предсказывать метрики, которые считаются только на аудио или на эмбеддингах, и использовать как дополнительные таргеты при обучении.
- Wasserstein GAN (Adversarial Loss). Для стабильного обучения лучше дополнительно сравнивать фичи из спектрограмм.
- RLHF на человеческой разметке.

Для бейзлайна предлагаю использовать спектральные статистики + Reconstruction Loss (см ниже), чтобы не обучать дополнительные модели.




### Функции для расчёта лоссов

In [25]:
def spectral_centroid_fn(mel_spec):
  """Спектральный центроид"""
  mel_spec = mel_spec.unsqueeze(1)
  freqs = torch.linspace(0, 1, mel_spec.size(2), device=mel_spec.device)
  freqs = freqs.view(1, 1, -1, 1)
  weighted = mel_spec * freqs
  spectral_centroid = torch.sum(weighted, dim=2) / (torch.sum(mel_spec, dim=2) + 1e-8)
  return spectral_centroid

In [26]:
def spectral_bandwidth_fn(mel_spec):
  """
  Спектральная ширина - разброс частот
  """
  centroid = spectral_centroid_fn(mel_spec).unsqueeze(1)  # [batch, 1, time]

  mel_frequencies = torch.linspace(0, 1, config['mel_bins'], device=mel_spec.device)\
  .unsqueeze(0).unsqueeze(-1)

  # Нормализуем
  mel_spec_normalized = mel_spec / (torch.sum(mel_spec, dim=1, keepdim=True) + 1e-8)

  # Вычисляем дисперсию вокруг центроида
  freq_diff = (mel_frequencies - centroid) ** 2
  bandwidth = torch.sum(freq_diff * mel_spec_normalized, dim=1)  # [batch, time]

  return torch.sqrt(bandwidth + 1e-8)

In [27]:
def spectral_flatness_fn(mel_spec):
  """
  Спектральная плоскость - тональность vs шумность
  """
  # Геометрическое среднее
  geometric_mean = torch.exp(torch.mean(torch.log(mel_spec + 1e-8), dim=1))

  # Арифметическое среднее
  arithmetic_mean = torch.mean(mel_spec, dim=1)

  # Спектральная плоскость
  flatness = geometric_mean / (arithmetic_mean + 1e-8)

  return flatness

In [28]:
def spectral_entropy_fn(mel_spec):
  """
  Спектральная энтропия - сложность текстуры
  """
  # Нормализуем каждый временной кадр
  mel_normalized = mel_spec / (torch.sum(mel_spec, dim=1, keepdim=True) + 1e-8)

  # Вычисляем энтропию
  entropy = -torch.sum(mel_normalized * torch.log(mel_normalized + 1e-8), dim=1)

  # Нормализуем энтропию к [0, 1]
  max_entropy = torch.log(torch.tensor(mel_spec.size(1), device=mel_spec.device))
  normalized_entropy = entropy / max_entropy

  return normalized_entropy

In [29]:
def normalize(tensor):
  return tensor / tensor.norm()

In [30]:
def calc_loss(func, preds, y):
  preds = func(preds)
  y = func(y)
  return ((preds-y)**2).mean()

### Тренировочный цикл

In [31]:
# class LSTMModel(nn.Module):
#   def __init__(self
#                , mel_bins
#                , sequence_length_input
#                , sequence_length_output
#                , dropout_rate
#                , use_bidirectional
#                , device=device
#                , dtype=torch.float32):
#     super().__init__()

#     self.mel_bins = mel_bins
#     self.sequence_length_input = sequence_length_input
#     self.sequence_length_output = sequence_length_output
#     self.dropout_rate = dropout_rate
#     self.use_bidirectional = use_bidirectional
#     self.correction_coef = 1 if not use_bidirectional else 1/2


#     self.lstm1 = nn.LSTM(
#             input_size=self.sequence_length_input,
#             hidden_size=int(self.sequence_length_input*self.correction_coef),
#             num_layers=self.mel_bins,
#             batch_first=True,
#             bidirectional=self.use_bidirectional,
#             dropout=self.dropout_rate if self.mel_bins > 1 else 0
#             , device=device
#             , dtype=dtype
#     )
#     # self.lstm2 = nn.LSTM(
#     #         input_size=self.sequence_length_input,
#     #         hidden_size=int(self.sequence_length_input*self.correction_coef),
#     #         num_layers=self.mel_bins,
#     #         batch_first=True,
#     #         bidirectional=self.use_bidirectional,
#     #         dropout=self.dropout_rate if self.mel_bins > 1 else 0
#     #         , device=device
#     #         , dtype=dtype
#     # )
#     # self.lstm3 = nn.LSTM(
#     #         input_size=self.sequence_length_input,
#     #         hidden_size=int(self.sequence_length_input*self.correction_coef),
#     #         num_layers=self.mel_bins,
#     #         batch_first=True,
#     #         bidirectional=self.use_bidirectional,
#     #         dropout=self.dropout_rate if self.mel_bins > 1 else 0
#     #         , device=device
#     #         , dtype=dtype
#     # )
#     self.dropout = nn.Dropout(self.dropout_rate)
#     self.final_layer = nn.Linear(self.sequence_length_input
#                                  , self.sequence_length_output
#                                  , device=device
#             , dtype=dtype)

#   def _continue(self, X):

#     X, (hidden, cell) = self.lstm1(X) # в статье 3 раза lstm
#     # X, (hidden, cell) = self.lstm2(X, (hidden, cell))
#     # X, (hidden, cell) = self.lstm3(X, (hidden, cell))

#     hidden = hidden.view(-1, self.mel_bins, self.sequence_length_input)
#     return self.final_layer(X), self.final_layer(hidden)

#   def forward(self, X):
#     next_seq, hidden = self._continue(X)
#     return next_seq, hidden


# class BassFromHiddenModel(nn.Module):
#   def __init__(self
#                , sequence_length_input
#                , sequence_length_output
#                , device=device
#             , dtype=torch.float32):
#     super().__init__()
#     self.sequence_length_input = sequence_length_input
#     self.sequence_length_output = sequence_length_output
#     self.bass_parameters = nn.Parameter(torch.rand(self.sequence_length_input
#                                                    , self.sequence_length_output
#                                                    , device=device, dtype=dtype)
#     , requires_grad=True)
#     self.lr1 = nn.Linear(self.sequence_length_output
#                          , self.sequence_length_output//4
#                          , device=device
#             , dtype=dtype)
#     self.lr2 = nn.Linear(self.sequence_length_output//4
#                          , self.sequence_length_output
#                          , device=device
#             , dtype=dtype)
#     self.activ = nn.Sigmoid()

#   def forward(self, X):
#     return self.lr2(self.activ(self.lr1(X+self.bass_parameters)))


# class DrumsFromHiddenModel(nn.Module):
#   def __init__(self
#                , sequence_length_input
#                , sequence_length_output
#                , device=device
#             , dtype=torch.float32):
#     super().__init__()
#     self.sequence_length_input = sequence_length_input
#     self.sequence_length_output = sequence_length_output
#     self.drums_parameters = nn.Parameter(torch.rand(self.sequence_length_input
#                                                    , self.sequence_length_output
#                                                    , device=device, dtype=dtype)
#     , requires_grad=True)
#     self.lr1 = nn.Linear(self.sequence_length_output
#                          , self.sequence_length_output//4
#                          , device=device
#             , dtype=dtype)
#     self.lr2 = nn.Linear(self.sequence_length_output//4
#                          , self.sequence_length_output
#                          , device=device
#             , dtype=dtype)
#     self.activ = nn.Sigmoid()

#   def forward(self, X):
#     return self.lr2(self.activ(self.lr1(X+self.drums_parameters)))


In [32]:



class BassFromHiddenModel(nn.Module):
  def __init__(self
               , sequence_length_input
               , sequence_length_output
               , device=device
            , dtype=torch.float32):
    super().__init__()
    self.sequence_length_input = sequence_length_input
    self.sequence_length_output = sequence_length_output
    self.bass_parameters = nn.Parameter(torch.rand(self.sequence_length_input
                                                   , self.sequence_length_output
                                                   , device=device, dtype=dtype)
    , requires_grad=True)
    self.lr1 = nn.Linear(self.sequence_length_output
                         , self.sequence_length_output//4
                         , device=device
            , dtype=dtype)
    self.lr2 = nn.Linear(self.sequence_length_output//4
                         , self.sequence_length_output
                         , device=device
            , dtype=dtype)
    self.activ = nn.Sigmoid()

  def forward(self, X):
    return self.lr2(self.activ(self.lr1(X+self.bass_parameters)))


class DrumsFromHiddenModel(nn.Module):
  def __init__(self
               , sequence_length_input
               , sequence_length_output
               , device=device
            , dtype=torch.float32):
    super().__init__()
    self.sequence_length_input = sequence_length_input
    self.sequence_length_output = sequence_length_output
    self.drums_parameters = nn.Parameter(torch.rand(self.sequence_length_input
                                                   , self.sequence_length_output
                                                   , device=device, dtype=dtype)
    , requires_grad=True)
    self.lr1 = nn.Linear(self.sequence_length_output
                         , self.sequence_length_output//4
                         , device=device
            , dtype=dtype)
    self.lr2 = nn.Linear(self.sequence_length_output//4
                         , self.sequence_length_output
                         , device=device
            , dtype=dtype)
    self.activ = nn.Sigmoid()

  def forward(self, X):
    return self.lr2(self.activ(self.lr1(X+self.drums_parameters)))


class LSTMModel(nn.Module):
  def __init__(self
               , mel_bins
               , sequence_length_input
               , sequence_length_output
               , dropout_rate
               , use_bidirectional
               , device=device
               , dtype=torch.float32):
    super().__init__()

    self.mel_bins = mel_bins
    self.sequence_length_input = sequence_length_input
    self.sequence_length_output = sequence_length_output
    self.dropout_rate = dropout_rate
    self.use_bidirectional = use_bidirectional
    self.correction_coef = 1 if not use_bidirectional else 1/2

    # self.bass_model = BassFromHiddenModel(sequence_length_input=config['mel_bins']
    #            , sequence_length_output=config['sequence_length_output']).to(device)

    # self.drums_model = DrumsFromHiddenModel(sequence_length_input=config['mel_bins']
    #            , sequence_length_output=config['sequence_length_output']).to(device)


    self.lstm1 = nn.LSTM(
            input_size=self.sequence_length_input,
            hidden_size=int(self.sequence_length_input*self.correction_coef),
            num_layers=self.mel_bins,
            batch_first=True,
            bidirectional=self.use_bidirectional,
            dropout=self.dropout_rate if self.mel_bins > 1 else 0
            , device=device
            , dtype=dtype
    )
    self.lstm2 = nn.LSTM(
            input_size=self.sequence_length_input,
            hidden_size=int(self.sequence_length_input*self.correction_coef),
            num_layers=self.mel_bins,
            batch_first=True,
            bidirectional=self.use_bidirectional,
            dropout=self.dropout_rate if self.mel_bins > 1 else 0
            , device=device
            , dtype=dtype
    )
    self.lstm3 = nn.LSTM(
            input_size=self.sequence_length_input,
            hidden_size=int(self.sequence_length_input*self.correction_coef),
            num_layers=self.mel_bins,
            batch_first=True,
            bidirectional=self.use_bidirectional,
            dropout=self.dropout_rate if self.mel_bins > 1 else 0
            , device=device
            , dtype=dtype
    )
    self.dropout = nn.Dropout(self.dropout_rate)
    self.final_layer = nn.Linear(self.sequence_length_input
                                 , self.sequence_length_output
                                 , device=device
            , dtype=dtype)

  def _continue(self, X):

    next_seq, (hidden, cell) = self.lstm1(X) # в статье 3 раза lstm
    bass, _ = self.lstm2(next_seq, (hidden, cell))
    drums, _ = self.lstm3(next_seq, (hidden, cell))

    # hidden = hidden.view(-1, self.mel_bins, self.sequence_length_input)
    return self.final_layer(next_seq), self.final_layer(bass), self.final_layer(drums)

  def forward(self, X, batch_size):
    next_seq, bass_pred, drums_pred = self._continue(X)
    return next_seq, bass_pred[:batch_size], drums_pred[:batch_size]


In [33]:
# from google.colab import userdata
# import os
# os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')

In [34]:
# import wandb
# run = wandb.init(
#     project="jam-assistant",
#     config=config,
# )

In [35]:
model = LSTMModel(mel_bins=config['mel_bins']
               , sequence_length_input=config['sequence_length_input']
               , sequence_length_output=config['sequence_length_output']
               , dropout_rate=config['dropout_rate']
               , use_bidirectional=config['use_bidirectional']).to(device)
# bass_model = BassFromHiddenModel(sequence_length_input=config['mel_bins']
#                , sequence_length_output=config['sequence_length_output']).to(device)
# drums_model = DrumsFromHiddenModel(sequence_length_input=config['mel_bins']
#                , sequence_length_output=config['sequence_length_output']).to(device)
learning_rate = config['learning_rate']
optim = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
melspec = T.MelSpectrogram(config['target_sample_rate']
                           , n_fft=config['n_fft']
                           , n_mels=config['mel_bins']
                           , hop_length=config['hop_length']).to(device)

# optim_bass = torch.optim.RMSprop(bass_model.parameters(), lr=learning_rate)
# optim_drums = torch.optim.RMSprop(drums_model.parameters(), lr=learning_rate)
num_epochs = config['num_epochs']
scaler = torch.amp.GradScaler(device)

for epoch in range(num_epochs):
  model.train()
  optim.zero_grad()


  for batch in mus_dl_train:

    # сначала учим генерировать продолжение
    with torch.autocast(device_type=device
                        , dtype=torch.float16):

      vocals_X, vocals_y =  batch['vocals']
      drums_X, drums_y = batch['drums']
      bass_X, bass_y = batch['bass']
      other_X, other_y = batch['other']
      all_X, all_y = batch['other']

      samples_in_batch = vocals_X.shape[0]

      stems_X = torch.cat([vocals_X, drums_X, bass_X, other_X], dim=0).to(device)
      stems_y = torch.cat([vocals_y, drums_y, bass_y, other_y], dim=0).to(device)

      del vocals_X
      del vocals_y
      del drums_X
      del drums_y
      del bass_X
      del bass_y
      del other_X
      del other_y
      del all_X
      del all_y

      melspecs_X = melspec(stems_X)
      melspecs_y = melspec(stems_y)

      if config['target_log']:
        melspecs_X = torch.log1p(melspecs_X)
        melspecs_y = torch.log1p(melspecs_y)

      melspecs_pred, melspecs_pred_bass, melspecs_pred_drums = model(melspecs_X, batch_size=samples_in_batch)
      # hidden = hidden[:samples_in_batch]
      # melspecs_pred_bass = bass_model(hidden)
      # melspecs_pred_drums = drums_model(hidden)

      if config['target_log']:
        melspecs_pred = torch.expm1(melspecs_pred)
        melspecs_pred_bass = torch.expm1(melspecs_pred_bass)
        melspecs_pred_drums = torch.expm1(melspecs_pred_drums)

      normalized_pred = normalize(melspecs_pred)
      normalized_bass_pred = normalize(melspecs_pred_bass)
      normalized_drums_pred = normalize(melspecs_pred_drums)
      normalized_y = normalize(melspecs_y)

      # без нормализации - гигантские значения лосса

      reconstruction_loss = ((normalized_pred-normalized_y)**2).sum()
      reconstruction_loss_without_normalization = (
          (melspecs_pred-melspecs_y)**2
          ).mean()

      spectral_centroid_loss = calc_loss(spectral_centroid_fn
                                        , melspecs_pred
                                        , melspecs_y)
      # spectral_bandwidth_loss = calc_loss(spectral_bandwidth_fn
      #                                     , melspecs_pred
      #                                     , melspecs_y)
      # spectral_flatness_loss = calc_loss(spectral_flatness_fn
      #                                   , melspecs_pred
      #                                   , melspecs_y)
      # spectral_entropy_loss = calc_loss(spectral_entropy_fn
      #                                   , melspecs_pred
      #                                   , melspecs_y)

      reconstruction_loss_bass = ((normalized_bass_pred-normalized_y[samples_in_batch*2:samples_in_batch*3])**2).sum()
      reconstruction_loss_drums = ((normalized_drums_pred\
                                    -normalized_y[samples_in_batch*1:samples_in_batch*2])**2).sum()
      # run.log({'reconstruction_loss_train':reconstruction_loss,
              #  'reconstruction_loss_bass_train':reconstruction_loss_bass,
              #  'reconstruction_loss_drums_train':reconstruction_loss_drums})



    scaler.scale(reconstruction_loss).backward(retain_graph=True)
    scaler.scale(reconstruction_loss_bass).backward(retain_graph=True)
    scaler.scale(reconstruction_loss_drums).backward(retain_graph=True)

    
    # reconstruction_loss.backward(retain_graph=True)
    # reconstruction_loss_bass.backward(retain_graph=True)
    # reconstruction_loss_drums.backward(retain_graph=True)

    if config['spectral_backward']:
      scaler.scale(spectral_centroid_loss).backward(retain_graph=True)
    #   spectral_bandwidth_loss.backward(retain_graph=True)
    #   spectral_flatness_loss.backward(retain_graph=True)
    #   spectral_entropy_loss.backward(retain_graph=True)
  scaler.step(optim)
  # scaler.step(optim_bass)
  # scaler.step(optim_drums)

  scaler.update()
  with torch.no_grad():
    del batch
    del stems_X
    del stems_y
    del melspecs_X
    del melspecs_y
    del melspecs_pred
    torch.cuda.empty_cache()
    gc.collect()
  # optim.step()
  # optim_bass.step()
  # optim_drums.step()
    # break


  model.eval()

  with torch.no_grad():
    with torch.autocast(device_type=device
                    , dtype=torch.float16):
      tmp_losses = {'reconstruction_loss':[]
                    , 'reconstruction_loss_bass':[]
                    , 'reconstruction_loss_drums':[]
                    , 'spectral_centroid_loss':[]
                    , 'spectral_bandwidth_loss':[]
                    , 'spectral_flatness_loss':[]
                    , 'spectral_entropy_loss':[]}
      for batch in mus_dl_test:
        vocals_X, vocals_y =  batch['vocals']
        drums_X, drums_y = batch['drums']
        bass_X, bass_y = batch['bass']
        other_X, other_y = batch['other']
        all_X, all_y = batch['other']

        stems_X = torch.cat([vocals_X, drums_X, bass_X, other_X], dim=0).to(device)
        stems_y = torch.cat([vocals_y, drums_y, bass_y, other_y], dim=0).to(device)

        samples_in_batch = vocals_X.shape[0]

        melspecs_X = melspec(stems_X)
        melspecs_y = melspec(stems_y)
        if config['target_log']:
          melspecs_X = torch.log1p(melspecs_X)
          melspecs_y = torch.log1p(melspecs_y)

        melspecs_pred, melspecs_pred_bass, melspecs_pred_drums = model(melspecs_X, batch_size=samples_in_batch)
        # hidden = hidden[:vocals_X.shape[0]]
        # melspecs_pred_bass = bass_model(hidden)
        # melspecs_pred_drums = drums_model(hidden)

        if config['target_log']:
          melspecs_pred = torch.expm1(melspecs_pred)
          melspecs_pred_bass = torch.expm1(melspecs_pred_bass)
          melspecs_pred_drums = torch.expm1(melspecs_pred_drums)

        normalized_pred = normalize(melspecs_pred)
        normalized_bass_pred = normalize(melspecs_pred_bass)
        normalized_drums_pred = normalize(melspecs_pred_drums)
        normalized_y = normalize(melspecs_y)

        # без нормализации - гигантские значения лосса
        reconstruction_loss = ((normalized_pred-normalized_y)**2).sum()
        reconstruction_loss_without_normalization = (
            (melspecs_pred-melspecs_y)**2
            ).mean()

        spectral_centroid_loss = calc_loss(spectral_centroid_fn
                                          , melspecs_pred
                                          , melspecs_y)
        spectral_bandwidth_loss = calc_loss(spectral_bandwidth_fn
                                            , melspecs_pred
                                            , melspecs_y)
        spectral_flatness_loss = calc_loss(spectral_flatness_fn
                                          , melspecs_pred
                                          , melspecs_y)
        spectral_entropy_loss = calc_loss(spectral_entropy_fn
                                          , melspecs_pred
                                          , melspecs_y)

        reconstruction_loss_bass = ((normalized_bass_pred-normalized_y[bass_X.shape[0]*2:bass_X.shape[0]*3])**2).sum()
        reconstruction_loss_drums = ((normalized_drums_pred-normalized_y[drums_X.shape[0]*1:bass_X.shape[0]*2])**2).sum()

        tmp_losses['reconstruction_loss'].append(reconstruction_loss)
        tmp_losses['reconstruction_loss_bass'].append(reconstruction_loss_bass)
        tmp_losses['reconstruction_loss_drums'].append(reconstruction_loss_drums)
        tmp_losses['spectral_centroid_loss'].append(spectral_centroid_loss)
        tmp_losses['spectral_bandwidth_loss'].append(spectral_bandwidth_loss)
        tmp_losses['spectral_flatness_loss'].append(spectral_flatness_loss)
        tmp_losses['spectral_entropy_loss'].append(spectral_entropy_loss)

        del batch
        del stems_X
        del stems_y
        del melspecs_X
        del melspecs_y
        del vocals_X
        del vocals_y
        del drums_X
        del drums_y
        del bass_X
        del bass_y
        del other_X
        del other_y
        del all_X
        del all_y
        del melspecs_pred
        torch.cuda.empty_cache()
        gc.collect()


  print('epoch:', epoch)
  print('reconstruction loss:', sum(tmp_losses['reconstruction_loss'])/len(tmp_losses['reconstruction_loss']))
  print('reconstruction loss bass:', sum(tmp_losses['reconstruction_loss_bass'])/len(tmp_losses['reconstruction_loss_bass']))
  print('reconstruction loss drums:', sum(tmp_losses['reconstruction_loss_drums'])/len(tmp_losses['reconstruction_loss_drums']))
  print('spectral_centroid_loss:', sum(tmp_losses['spectral_centroid_loss'])/len(tmp_losses['spectral_centroid_loss']))
  print('spectral_bandwidth_loss:', sum(tmp_losses['spectral_bandwidth_loss'])/len(tmp_losses['spectral_bandwidth_loss']))
  print('spectral_flatness_loss:', sum(tmp_losses['spectral_flatness_loss'])/len(tmp_losses['spectral_flatness_loss']))
  print('spectral_entropy_loss:', sum(tmp_losses['spectral_entropy_loss'])/len(tmp_losses['spectral_entropy_loss']))
  # run.log(tmp_losses)

epoch: 0
reconstruction loss: tensor(2.0336, device='cuda:0')
reconstruction loss bass: tensor(1.1451, device='cuda:0')
reconstruction loss drums: tensor(1.1385, device='cuda:0')
spectral_centroid_loss: tensor(0.0916, device='cuda:0')
spectral_bandwidth_loss: tensor(nan, device='cuda:0')
spectral_flatness_loss: tensor(nan, device='cuda:0')
spectral_entropy_loss: tensor(nan, device='cuda:0')
epoch: 1
reconstruction loss: tensor(1.8463, device='cuda:0')
reconstruction loss bass: tensor(1.1050, device='cuda:0')
reconstruction loss drums: tensor(1.0825, device='cuda:0')
spectral_centroid_loss: tensor(0.0920, device='cuda:0')
spectral_bandwidth_loss: tensor(nan, device='cuda:0')
spectral_flatness_loss: tensor(nan, device='cuda:0')
spectral_entropy_loss: tensor(nan, device='cuda:0')
epoch: 2
reconstruction loss: tensor(1.7162, device='cuda:0')
reconstruction loss bass: tensor(1.0783, device='cuda:0')
reconstruction loss drums: tensor(1.0426, device='cuda:0')
spectral_centroid_loss: tensor(0.

In [36]:
import pickle as pkl
with torch.no_grad():
    torch.save(model.state_dict(), '/kaggle/working/lstm.pt')

In [37]:
# TODO:

# понять, почему спектральные метрики скатываются в nan?
# как можно объединить расчёт loss восстановленных и спектральных?
# код в более промышленном варианте - сделать функцию чтобы считать для каждого отдельно набора стемов и спектральные и реконструкцию
# попробовать восстановить звук из мелки
# расширить датасет