In [None]:
# ! mkdir -p all_wavs

In [None]:
# ! find data/asr_calls_2_val -type f -name '*.wav' -exec cp {} all_wavs/ \;

In [None]:
# ! python3 -m pip install librosa

Defaulting to user installation because normal site-packages is not writeable
Collecting librosa
  Downloading librosa-0.11.0-py3-none-any.whl (260 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m260.7/260.7 KB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting soxr>=0.3.2
  Downloading soxr-0.5.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (252 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m252.8/252.8 KB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting soundfile>=0.12.1
  Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0mm
[?25hCollecting msgpack>=1.0
  Downloading msgpack-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (378 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m378.0/378.0 KB[0m [31m5.

In [8]:
import os
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from tqdm import tqdm
import torch
from torch import nn
from torch.nn import functional as F
import pandas as pd
import numpy as np
import torchaudio

# Модель

In [9]:
# Experiment configuration dictionary.
# NOTE(review): in the code visible in this notebook only
# config['train']['batch_size'] is read (when the dataloaders are built).
# The other entries are not referenced here — presumably consumed elsewhere or
# left over from another experiment; TODO confirm before relying on them.
config = { #FIX  (author's marker; what needs fixing is not stated)
    # File locations for the train/val splits (table = CSV path, data = directory).
    "dataset": {
        "train": {
            "table": "./data/train.csv",
            "data": "./data/asr_calls_spec/"
        },
        "val": {
            "table": "./data/val.csv",
            "data": "./data/asr_calls_spec/"
        }
    },
    # Training-loader options; only "batch_size" is used by the visible code.
    "train": {
        "batch_size": 128,
        "grad_acum": 1,
        "dtype": "float32",
        'shuffle': True,
        'pin_memory': True,
    },
    # Validation-loader options; not read by the visible code
    # (get_dataloaders receives a single batch size).
    "val": {
        "batch_size": 1024,
        "grad_acum": 1,
        "dtype": "float32",
        'shuffle': False,
        'pin_memory': True,
    },
    # NOTE(review): MelSpectrogramDataset is built with its own defaults
    # (sr=22050, duration=2.0), not with these values — confirm which is intended.
    "vae": {
        "freq": 16000,
        # The key is spelled "lenght"; it is a runtime key, so it is left as-is.
        "lenght": 5,
    },
    "model": {
        "latent_size": 128,
        "freq_scale": 4,
        "time_scale": 4,
    },
    "utils": {
        "n_fft": 800,
    }
}

In [10]:
from sklearn.model_selection import train_test_split

def split_dataset(file_paths, test_size=0.2, seed=42):
    """Randomly partition ``file_paths`` into a training part and a held-out part.

    Thin wrapper around ``sklearn.model_selection.train_test_split``.

    Args:
        file_paths: Collection of items (here: audio file paths) to split.
        test_size: Forwarded unchanged as ``test_size``.
        seed: Forwarded as ``random_state`` so the shuffle is repeatable.

    Returns:
        The value returned by ``train_test_split`` for a single input
        collection: the training part followed by the held-out part.
    """
    split_options = {"test_size": test_size, "random_state": seed}
    return train_test_split(file_paths, **split_options)

In [11]:
import torch
import torchaudio
import librosa
import numpy as np
from torch.utils.data import Dataset

class MelSpectrogramDataset(Dataset):
    """Map-style dataset yielding one normalised mel spectrogram per audio file.

    Every clip is loaded with librosa, forced to exactly ``int(sr * duration)``
    samples (zero-padded at the end or truncated), converted to a mel
    spectrogram, optionally converted to dB, min-max scaled and returned as a
    float32 tensor with a leading channel axis.

    Args:
        wav_paths: Indexable collection of paths to audio files.
        sr: Sample rate passed to ``librosa.load`` and to the mel transform.
        duration: Clip length in seconds used to derive the fixed sample count.
        n_mels: Number of mel bands (spectrogram height).
        hop_length: Hop between successive analysis frames, in samples.
        transform: Optional callable applied to the final tensor.
        to_log: When true, the power spectrogram is converted to dB.
    """

    def __init__(
        self,
        wav_paths,
        sr=22050,
        duration=2.0,
        n_mels=128,
        hop_length=512,
        transform=None,
        to_log=True
    ):
        self.wav_paths = wav_paths
        self.sr, self.duration = sr, duration
        self.n_mels, self.hop_length = n_mels, hop_length
        self.transform, self.to_log = transform, to_log
        # Number of samples every clip is padded/truncated to.
        self.fixed_length = int(sr * duration)

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.wav_paths)

    def __getitem__(self, idx):
        """Load file ``idx`` and return its spectrogram as a ``(1, H, W)`` float32 tensor."""
        samples, _ = librosa.load(self.wav_paths[idx], sr=self.sr)

        # Force a fixed clip length: zero-pad short clips on the right, cut long ones.
        shortfall = self.fixed_length - len(samples)
        if shortfall > 0:
            samples = np.pad(samples, (0, shortfall), mode='constant')
        else:
            samples = samples[:self.fixed_length]

        spec = librosa.feature.melspectrogram(
            y=samples,
            sr=self.sr,
            n_mels=self.n_mels,
            hop_length=self.hop_length
        )

        # Optional conversion to dB, referenced to the spectrogram's own maximum.
        if self.to_log:
            spec = librosa.power_to_db(spec, ref=np.max)

        # Min-max scaling; the epsilon avoids division by zero for constant input.
        spec = (spec - spec.min()) / (spec.max() - spec.min() + 1e-6)

        # Add the channel axis in front.
        item = torch.tensor(spec, dtype=torch.float32).unsqueeze(0)

        if self.transform:
            item = self.transform(item)

        return item

In [12]:
from torch.utils.data import DataLoader

def get_dataloaders(train_files, val_files, batch_size=32, num_workers=0,
                    pin_memory=False, val_batch_size=None):
    """Build the training and validation ``DataLoader`` objects.

    Both file lists are wrapped in ``MelSpectrogramDataset`` with its default
    settings. The training loader shuffles; the validation loader does not.

    Args:
        train_files: Audio paths for the training split.
        val_files: Audio paths for the validation split.
        batch_size: Batch size of the training loader (and of the validation
            loader unless ``val_batch_size`` is given).
        num_workers: Forwarded to both loaders. The default of 0 keeps loading
            in the main process, which matches the previous behaviour.
        pin_memory: Forwarded to both loaders (default ``False`` as before).
        val_batch_size: Optional separate batch size for validation;
            ``None`` means "same as ``batch_size``".

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    if val_batch_size is None:
        val_batch_size = batch_size

    shared_options = {"num_workers": num_workers, "pin_memory": pin_memory}

    train_loader = DataLoader(
        MelSpectrogramDataset(train_files),
        batch_size=batch_size,
        shuffle=True,
        **shared_options,
    )
    val_loader = DataLoader(
        MelSpectrogramDataset(val_files),
        batch_size=val_batch_size,
        shuffle=False,
        **shared_options,
    )

    return train_loader, val_loader

In [25]:
class VAE_Audio(nn.Module):
    """Convolutional variational auto-encoder for single-channel spectrograms.

    The encoder is a stack of 3x3 convolutions that pad only the height axis
    (``padding=(1, 0)``), so every such layer also trims the width. For a
    ``(B, 1, 128, 87)`` input the latent tensors ``mu``/``logvar`` have shape
    ``(B, 64, 4, 1)``. The decoder mirrors this with transposed convolutions
    and finally resizes to ``output_size`` with adaptive average pooling,
    followed by ``tanh``.

    Args:
        output_size: ``(height, width)`` of the reconstruction. Defaults to
            ``(128, 87)``, the size that was previously hard-coded.
    """

    def __init__(self, output_size=(128, 87)):
        super().__init__()
        self.output_size = tuple(output_size)

        # Encoder stem: keeps the spatial size, lifts 1 -> 16 channels.
        self.encoder_input = nn.Sequential(
            nn.Conv2d(1, 1, 1),
            nn.Tanh(),
            nn.Conv2d(1, 8, 3, 1, 1),
            nn.BatchNorm2d(8),
            nn.GELU(),
            nn.Conv2d(8, 16, 3, 1, 1)
        )

        # Encoder body: one stride-1 conv followed by three stride-2 convs.
        # Padding is applied along the height only.
        self.encoder_main = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, (1, 0)),
            nn.GELU(),
            nn.Conv2d(32, 32, 3, 2, (1, 0)),
            nn.GELU(),
            nn.Conv2d(32, 32, 3, 2, (1, 0)),
            nn.GELU(),
            nn.Conv2d(32, 16, 3, 2, (1, 0))
        )

        # Two further stride-2 convs (no activation in between) down to 64 channels.
        self.encoder_squeeze = nn.Sequential(
            nn.Conv2d(16, 32, 3, 2, (1, 0)),
            nn.Conv2d(32, 64, 3, 2, (1, 0))
        )

        # 1x1 heads producing the posterior mean and log-variance.
        self.encoder_mu = nn.Conv2d(64, 64, 1)
        self.encoder_logvar = nn.Conv2d(64, 64, 1)

        # Decoder: five stride-2 transposed convs; output_padding=(1, 0) makes
        # each of them exactly double the height.
        self.decoder_unsqueeze = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, 2, (1, 0), output_padding=(1, 0)),
            nn.ConvTranspose2d(32, 16, 3, 2, (1, 0), output_padding=(1, 0)),
            nn.ConvTranspose2d(16, 32, 3, 2, (1, 0), output_padding=(1, 0)),
            nn.ConvTranspose2d(32, 32, 3, 2, (1, 0), output_padding=(1, 0)),
            nn.ConvTranspose2d(32, 16, 3, 2, (1, 0), output_padding=(1, 0))
        )

        self.decoder_output = nn.Sequential(
            nn.Conv2d(16, 8, 3, padding=(1, 0)),
            nn.BatchNorm2d(8),
            nn.GELU(),
            nn.Conv2d(8, 1, 1),
            # Force the requested spatial size regardless of what the
            # transposed convolutions produced.
            nn.AdaptiveAvgPool2d(self.output_size),
            nn.Tanh()
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for input ``x``."""
        features = self.encoder_squeeze(self.encoder_main(self.encoder_input(x)))
        return self.encoder_mu(features), self.encoder_logvar(features)

    def decode(self, z):
        """Map a latent tensor ``z`` to a reconstruction of size ``output_size``."""
        return self.decoder_output(self.decoder_unsqueeze(z))

    def sample(self, x):
        """Encode ``x`` and draw a latent sample with the reparameterisation trick.

        Returns:
            Tuple ``(z, mu, logvar)`` where ``z = mu + eps * exp(0.5 * logvar)``
            and ``eps`` is standard normal noise.
        """
        mu, logvar = self.encode(x)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z, mu, logvar

    def KLD_loss(self, mu, logvar, q=0.02):
        """KL divergence to a standard normal, floored element-wise at ``q``.

        Each element's KL term is clamped from below to ``q`` before the mean
        is taken, so elements whose KL is already below ``q`` contribute a
        constant and receive no gradient from this term.
        """
        kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        kld = torch.clamp(kld, min=q)
        return kld.mean()

    def forward(self, x):
        """Return ``(reconstruction, z, mu, logvar)`` for input ``x``."""
        z, mu, logvar = self.sample(x)
        return self.decode(z), z, mu, logvar

In [23]:
from tqdm.auto import tqdm

class AvegereMeter:
    """Collects values and reports their arithmetic mean.

    A value recorded with weight ``n`` is stored ``n`` times in ``arr``, so the
    mean over ``arr`` is the weighted mean of the recorded values.
    """

    def __init__(self,):
        self.arr = []

    def __call__(self, item, n=1):
        """Record ``item`` ``n`` times (once when ``n`` is 1 or less)."""
        repeats = max(n, 1)
        self.arr.extend([item] * repeats)

    def __str__(self,) -> str:
        """Return the mean of all recorded values as a string."""
        values = np.array(self.arr)
        return str(values.mean())

    def zero(self,):
        """Forget everything recorded so far."""
        self.arr = []

class VAE_Trainer:
    """Runs training and validation epochs for a VAE model.

    Uses Adam (lr=1e-3), an MSE reconstruction loss and a ``StepLR`` schedule
    that halves the learning rate every 10 calls to ``train_loop``.

    Args:
        model: Module exposing ``sample(x) -> (z, mu, logvar)``, ``decode(z)``
            and ``KLD_loss(mu, logvar)``.
        train_dataloader: Iterable of training batches (tensors).
        val_dataloader: Iterable of validation batches (tensors).
        image_dir: Directory where ``save_image`` writes its PNG files. The
            default is the location that was previously hard-coded.
    """

    def __init__(self, model, train_dataloader, val_dataloader,
                 image_dir='./data/asr_calls_spec/train'):
        self.model = model
        self.tdl = train_dataloader
        self.vdl = val_dataloader
        self.optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        self.rec_loss = nn.MSELoss()
        self.loss_meter = AvegereMeter()
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.5)
        self.image_dir = image_dir
        # Number of validation passes run so far; used to number saved images.
        self.val_runs = 0

    def _model_device(self):
        """Return the device of the model's first parameter."""
        return next(self.model.parameters()).device

    @staticmethod
    def _autocast_device_type():
        """Device type handed to ``torch.autocast`` (chosen by CUDA availability)."""
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    def train_loop(self, k=0.01):
        """Run one training epoch and print the sample-weighted mean loss.

        Args:
            k: Weight of the KL term added to the reconstruction loss.
        """
        self.model.train()
        self.loss_meter.zero()
        device = self._model_device()
        for batch in tqdm(self.tdl):
            # Keep inputs on the same device as the model's parameters.
            batch = batch.to(device)
            with torch.autocast(self._autocast_device_type()):
                z, mu, logvar = self.model.sample(batch)
                output = self.model.decode(z)
                loss = self.rec_loss(output, batch)+k*self.model.KLD_loss(mu, logvar)
            loss.backward()
            self.loss_meter(loss.item(), batch.shape[0])
            self.optimizer.step()
            self.optimizer.zero_grad()
        # The schedule advances once per epoch, not once per batch.
        self.scheduler.step()
        print("Loss = "+self.loss_meter.__str__())

    def save_image(self, audio, output, iter = 0):
        """Save one random input/reconstruction pair from the batch as PNG files.

        Both tensors are clamped to ``[0, 1]``, replicated to three channels
        and written to ``image_dir`` as ``input_{iter}.png`` and
        ``output_{iter}.png``.

        Args:
            audio: Batch of input tensors, one channel each.
            output: Batch of reconstructions, same layout as ``audio``.
            iter: Index embedded in the file names.
        """
        # Imported here so the method does not depend on a later notebook cell
        # having been executed.
        from torchvision import transforms

        idx = torch.randint(0, audio.shape[0], (1,)).item()
        input_tensor = audio[idx].cpu().detach().clamp(0, 1).to(torch.float32)
        output_tensor = output[idx].cpu().detach().clamp(0, 1).to(torch.float32)

        transform = transforms.ToPILImage('RGB')
        input_image = transform(torch.cat([input_tensor]*3, dim=0))
        output_image = transform(torch.cat([output_tensor]*3, dim=0))

        # Create the target directory on demand instead of failing on save.
        os.makedirs(self.image_dir, exist_ok=True)
        input_image.save(os.path.join(self.image_dir, f'input_{iter}.png'))
        output_image.save(os.path.join(self.image_dir, f'output_{iter}.png'))

    def val_loop(self, step=None):
        """Run one validation pass and print the sample-weighted mean MSE.

        An input/reconstruction image pair is saved for the first batch only.

        Args:
            step: Index used in the saved image file names. When omitted, the
                number of validation passes run so far (starting at 1) is used,
                so no module-level counter is required.
        """
        self.model.eval()
        self.loss_meter.zero()
        self.val_runs += 1
        image_index = self.val_runs if step is None else step
        device = self._model_device()
        # No gradients are needed for validation; skip building the graph.
        with torch.no_grad():
            for batch_number, batch in enumerate(tqdm(self.vdl)):
                batch = batch.to(device)
                with torch.autocast(self._autocast_device_type()):
                    z, mu, logvar = self.model.sample(batch)
                    output = self.model.decode(z)
                    loss = self.rec_loss(output, batch)
                    if batch_number == 0:
                        self.save_image(batch, output, image_index)
                self.loss_meter(loss.item(), batch.shape[0])
        print("Val loss = "+self.loss_meter.__str__())


In [None]:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# --- Training run ---------------------------------------------------------
SEED = 42
NUM_EPOCHS = 10
KLD_WEIGHT = 0.01
WAV_DIR = Path('data/asr_calls_wavs')
CHECKPOINT_PATH = "vae_spec.pt"

# Seed torch so weight initialisation, shuffling and latent sampling are repeatable.
torch.manual_seed(SEED)

vae = VAE_Audio()

# rglob yields files in a filesystem-dependent order; sorting makes the seeded
# train/val split identical across machines and runs.
file_paths = sorted(WAV_DIR.rglob("*.wav"))
if not file_paths:
    raise FileNotFoundError(f"No .wav files found under {WAV_DIR}")
train_files, val_files = split_dataset(file_paths, seed=SEED)

train_dataloader, val_dataloader = get_dataloaders(
    train_files, val_files, batch_size=config['train']['batch_size']
)

trainer = VAE_Trainer(vae, train_dataloader, val_dataloader)

# `i` is kept as a module-level epoch counter (1-based): code elsewhere in this
# notebook (e.g. the image saving in VAE_Trainer.val_loop) may look it up by name.
i = 0
for epoch in tqdm(range(NUM_EPOCHS)):
    i += 1
    trainer.train_loop(KLD_WEIGHT)
    trainer.val_loop()

# Persist only the weights; reload with VAE_Audio().load_state_dict(...).
torch.save(vae.state_dict(), CHECKPOINT_PATH)

  0%|                                                                                            | 0/10 [00:00<?, ?it/s]
  0%|                                                                                            | 0/83 [00:00<?, ?it/s][A
  1%|█                                                                                   | 1/83 [00:23<31:35, 23.12s/it][A
  2%|██                                                                                  | 2/83 [00:45<30:50, 22.85s/it][A
  4%|███                                                                                 | 3/83 [01:07<29:51, 22.39s/it][A
  5%|████                                                                                | 4/83 [01:30<29:35, 22.48s/it][A
  6%|█████                                                                               | 5/83 [01:52<29:02, 22.34s/it][A
  7%|██████                                                                              | 6/83 [02:12<27:55, 21.76s/it][A
  8%|██████

Loss = 0.03672541230334624



  0%|                                                                                            | 0/21 [00:00<?, ?it/s][A
  5%|████                                                                                | 1/21 [00:09<03:08,  9.42s/it][A
 10%|████████                                                                            | 2/21 [00:19<03:04,  9.72s/it][A
 14%|████████████                                                                        | 3/21 [00:28<02:52,  9.59s/it][A
 19%|████████████████                                                                    | 4/21 [00:36<02:29,  8.77s/it][A