Библиотеки

In [1]:
# Импорт PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchaudio

# Импорт NumPy
import numpy as np

# Импорт Matplotlib
import matplotlib.pyplot as plt

# Импорт Librosa
import librosa
import librosa.display

#импорт os
import os

Пути к входным и целевым данным

In [2]:
# Raw string literals keep every backslash of the Windows paths literal.
# In a normal literal, sequences such as "\д", "\D" or "\S" are invalid escapes:
# they trigger a SyntaxWarning on recent Python versions and are slated to
# become errors. The resulting string values are unchanged.
mix_dir = r"C:\для учебы\Диплом\Диплом маг\DataSet\SDataSet\mix"
inst_dir = r"C:\для учебы\Диплом\Диплом маг\DataSet\SDataSet\instruments"

Класс датасета

In [None]:
class MusicSeparationDataset(Dataset):
    """Pairs of (mixture waveform, stacked instrument stems) for source separation.

    Track file names are discovered by listing ``mix_dir`` for ``.flac`` files.
    For every track the stems are read from
    ``instruments_dir/<instrument>/<track file name>`` for each entry of
    ``instrument_classes``.
    NOTE(review): assumes each stem file has the same name (including the
    ``.flac`` extension) as the mix file -- confirm against the dataset layout.

    Args:
        mix_dir: directory containing the mixture ``.flac`` files.
        instruments_dir: directory containing one sub-directory per instrument.
    """

    def __init__(self, mix_dir, instruments_dir):
        self.mix_dir = mix_dir
        self.instruments_dir = instruments_dir

        # Sorted so that the index -> track mapping does not depend on the
        # order in which the OS happens to list the directory.
        self.tracks = sorted(f for f in os.listdir(mix_dir) if f.endswith('.flac'))

        self.instrument_classes = ['Bass', 'Drums', 'Guitars', 'Keys']

    def __len__(self):
        """Return the number of mixture tracks found in ``mix_dir``."""
        return len(self.tracks)

    def __getitem__(self, idx, sr=44100):
        """Load the mixture and all instrument stems of track ``idx``.

        Args:
            idx: index into ``self.tracks``.
            sr: stored on the instance as ``self.sr``; the audio is NOT
                resampled to this rate.

        Returns:
            ``(mix_waveform, target_tensor)`` where ``mix_waveform`` is the
            tensor returned by ``torchaudio.load`` for the mix file and
            ``target_tensor`` stacks the stems along a new leading axis in the
            order of ``self.instrument_classes``.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: from ``torch.stack`` if the stems differ in shape.
        """
        # Kept only for backward compatibility with code that reads these
        # attributes after indexing; nothing below depends on them.
        self.sr = sr
        self.track_id = self.tracks[idx]

        # Use the full file name. Stripping the extension (as done previously)
        # yields a path that does not match the '.flac' files listed in __init__.
        track_file = self.track_id

        # The file's own sample rate is discarded instead of overwriting `sr`.
        mix_waveform, _ = torchaudio.load(os.path.join(self.mix_dir, track_file))

        targets = [
            torchaudio.load(os.path.join(self.instruments_dir, inst, track_file))[0]
            for inst in self.instrument_classes
        ]

        target_tensor = torch.stack(targets)
        return mix_waveform, target_tensor


In [4]:
# Build the dataset from the mixture directory and the per-instrument stem directory.
data_set = MusicSeparationDataset(mix_dir=mix_dir, instruments_dir=inst_dir)

In [5]:
# Shape of the mixture waveform of the first track (first element of the returned pair).
data_set[0][0].shape

torch.Size([1, 10652672])

In [6]:
# Shape of the stacked instrument stems of the first track (second element of the returned pair).
data_set[0][1].shape

torch.Size([4, 1, 10652672])

In [7]:
# Number of tracks in the dataset.
len(data_set)

100

In [14]:
# Shape of the stacked stems of the first track; NOTE(review): repeats an earlier cell's check.
data_set[0][1].shape

torch.Size([4, 1, 10652672])

Data-loader

In [26]:
import torch
import torch.nn.functional as F

def collate_fn(batch):
    """Zero-pad variable-length tracks to a common length and stack them.

    Args:
        batch: non-empty sequence of ``(mix_waveform, target_tensor)`` pairs.
            mix_waveform: ``[1, T]``
            target_tensor: ``[num_instruments, 1, T]``
            Time is expected on the last axis of both tensors.

    Returns:
        ``(mix_batch, target_batch)`` with shapes ``[B, 1, max_len]`` and
        ``[B, num_instruments, 1, max_len]``, right-padded with zeros.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    mixes, targets = zip(*batch)  # tuple of mixes and tuple of targets

    # Longest time axis over BOTH mixes and targets. Deriving it from the mixes
    # alone would hand F.pad a negative amount for any longer target, which
    # silently crops it.
    max_len = max(
        max(mix.shape[-1] for mix in mixes),
        max(tgt.shape[-1] for tgt in targets),
    )

    # F.pad's (left, right) pair applies to the last dimension, i.e. time.
    padded_mixes = [F.pad(mix, (0, max_len - mix.shape[-1])) for mix in mixes]
    padded_targets = [F.pad(tgt, (0, max_len - tgt.shape[-1])) for tgt in targets]

    mix_batch = torch.stack(padded_mixes)       # [B, 1, max_len]
    target_batch = torch.stack(padded_targets)  # [B, N, 1, max_len]

    return mix_batch, target_batch


In [31]:
# Batches of four tracks in dataset order; collate_fn zero-pads each batch to a common length.
dataloader = DataLoader(
    dataset=data_set,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
)

In [32]:
# Smoke check of the loader: fetch a single batch and print the batched tensor shapes.
for mix, targets in dataloader:
    print("Mix shape:", mix.shape)        # [4, 1, T]
    print("Targets shape:", targets.shape)  # [4, N, 1, T]
    break  # only the first batch


Mix shape: torch.Size([4, 1, 11879936])
Targets shape: torch.Size([4, 4, 1, 11879936])


Модель

In [None]:
class PatchEmbed(nn.Module):
    """Patch embedding: tile a 4-D input into patches and project each one linearly.

    Args:
        patch_size: ``(t_p, f_p)`` size of one patch along axes 2 and 3.
        embed_dim: output size of the linear projection.
    """

    def __init__(self, patch_size, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        patch_dim = patch_size[0] * patch_size[1]
        self.proj = nn.Linear(patch_dim, embed_dim)

    def forward(self, x):
        """Embed ``x`` of shape ``[B, C, T, F]`` into ``[B, num_patches, embed_dim]``."""
        # Unpacking also enforces that the input is exactly 4-D.
        batch, _channels, _frames, _bins = x.shape
        patch_t, patch_f = self.patch_size

        # Non-overlapping windows (step == size) along axis 2, then axis 3:
        # [B, C, n_t, n_f, patch_t, patch_f]
        tiles = x.unfold(2, patch_t, patch_t)
        tiles = tiles.unfold(3, patch_f, patch_f)

        # Flatten every tile into one vector: [B, num_patches, patch_t * patch_f]
        tokens = tiles.reshape(batch, -1, patch_t * patch_f)
        return self.proj(tokens)


In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: self-attention and MLP, each with a residual.

    Args:
        embed_dim: token feature size.
        num_heads: number of attention heads passed to ``nn.MultiheadAttention``.
        ff_dim: hidden size of the feed-forward sub-layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``[B, N, embed_dim]``; the shape is preserved."""
        # Normalise once and reuse the result as query, key and value.
        # Calling norm1 three times produced the same values at three times the cost.
        normed = self.norm1(x)
        # Attention weights are never used, so skip computing them.
        attn_out = self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + attn_out
        x = x + self.ff(self.norm2(x))
        return x


In [None]:
class PatchDecoder(nn.Module):
    """Project patch tokens back to pixels and reassemble them into a 2-D map.

    Args:
        patch_size: ``(t_p, f_p)`` size of one patch.
        embed_dim: size of the incoming token vectors.
        out_channels: number of output channels.
        out_shape: ``(T, F)`` size of the reconstructed map.
    """

    def __init__(self, patch_size, embed_dim, out_channels, out_shape):
        super().__init__()
        self.patch_size = patch_size
        self.out_shape = out_shape
        self.out_channels = out_channels
        values_per_patch = patch_size[0] * patch_size[1] * out_channels
        self.proj = nn.Linear(embed_dim, values_per_patch)

    def forward(self, x):
        """Map tokens ``[B, num_patches, embed_dim]`` to ``[B, out_channels, T, F]``."""
        # Unpacking also enforces that the input is exactly 3-D.
        batch, _num_patches, _dim = x.shape
        patch_t, patch_f = self.patch_size
        frames, bins = self.out_shape

        # Size of the patch grid along each axis.
        grid_t = frames // patch_t
        grid_f = bins // patch_f

        # [B, n_t, n_f, C, patch_t, patch_f]: tokens are laid out row-major over the grid.
        tiles = self.proj(x).view(batch, grid_t, grid_f, self.out_channels, patch_t, patch_f)

        # Interleave grid and in-patch axes -> [B, C, n_t, patch_t, n_f, patch_f],
        # then merge each pair into the full T and F axes.
        tiles = tiles.permute(0, 3, 1, 4, 2, 5)
        return tiles.reshape(batch, self.out_channels, frames, bins)


In [None]:
class SimpleAudioTransformer(nn.Module):
    """Patch-embed a 2-D input, run a stack of transformer blocks, decode back to 2-D maps.

    Args:
        patch_size: ``(t_p, f_p)`` patch size shared by the embedder and the decoder.
        embed_dim: token feature size.
        num_heads: attention heads per transformer block.
        ff_dim: hidden size of each block's feed-forward sub-layer.
        num_layers: number of transformer blocks.
        n_outputs: number of output channels produced by the decoder.
        input_shape: ``(T, F)`` size the decoder reconstructs.
    """

    def __init__(self, patch_size=(16, 16), embed_dim=256, num_heads=4, ff_dim=512, num_layers=4, n_outputs=3, input_shape=(128, 128)):
        super().__init__()
        self.embed = PatchEmbed(patch_size, embed_dim)

        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerBlock(embed_dim, num_heads, ff_dim))
        self.transformer = nn.Sequential(*blocks)

        self.decoder = PatchDecoder(patch_size, embed_dim, n_outputs, input_shape)

    def forward(self, x):
        """Map ``x`` of shape ``[B, 1, T, F]`` to ``[B, n_outputs, T, F]``."""
        tokens = self.embed(x)             # [B, Num_patches, Embed_dim]
        tokens = self.transformer(tokens)  # [B, Num_patches, Embed_dim]
        return self.decoder(tokens)        # [B, n_outputs, T, F]
