In [28]:
import torch, torchvision, torchaudio, numpy as np, matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.cuda.amp import autocast, GradScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from pathlib import Path
import pandas as pd
import torchvision.utils as vutils
import torch.backends.cudnn as cudnn
# Let cuDNN autotune convolution algorithms; this pays off when input shapes stay
# fixed between batches (wav_to_mel pads/trims every sample to 80 frames).
torch.backends.cudnn.benchmark = True
# Fall back to CPU so the notebook can still run on a machine without CUDA;
# a hard-coded 'cuda' device would fail at the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [29]:
# Show which device the model and batches will be moved to.
device

device(type='cuda')

# Конфигурирование

In [30]:
# Experiment configuration. The "train" and "val" entries differ only in file
# names and in whether batches are shuffled, so each is generated from one template.
# NOTE(review): "lenght" is misspelled but is the key the data-loading code reads,
# so it must stay as-is. Several keys (grad_acum, dtype, model.latent_size /
# epochs / learning_rate / freq_scale / time_scale) do not appear to be read by
# the code in this notebook -- TODO wire them up or drop them.
# NOTE(review): absolute Windows paths (E:/data/...) -- consider a configurable data root.
config = {
    "dataset": {
        split: {"table": f"E:/data/{split}.csv", "data": "E:/data/bare_data/"}
        for split in ("train", "val")
    },
    **{
        split: {
            "batch_size": 32,
            "grad_acum": 1,
            "dtype": "float32",
            "shuffle": split == "train",
            "pin_memory": True,
        }
        for split in ("train", "val")
    },
    "vae": {
        "freq": 16000,  # sample rate (Hz) every clip is resampled to
        "lenght": 5,    # clip length in seconds (pad/trim target)
    },
    "model": {
        "latent_size": 128,
        "epochs": 15,
        "learning_rate": 0.001,
        "freq_scale": 4,
        "time_scale": 4,
    },
    "utils": {
        "n_fft": 800,
    },
}

# Загрузка данных

In [31]:
# Log-mel front end used by wav_to_mel. Sample rate and FFT/window size are read
# from `config`, so they cannot silently drift from the rate AudioDataset resamples
# to (config["vae"]["freq"]). The values are unchanged: 16000 Hz and 800 samples.
# hop_length=200 is 12.5 ms at 16 kHz; n_mels=80 sets the spectrogram height.
mel_spec = torchaudio.transforms.MelSpectrogram(
    sample_rate=config["vae"]["freq"],
    n_fft=config["utils"]["n_fft"],
    hop_length=200,
    win_length=config["utils"]["n_fft"],
    n_mels=80,
)

def wav_to_mel(wav, n_frames=80, transform=None):
    """Convert a waveform to a log-mel spectrogram with exactly ``n_frames`` frames.

    Args:
        wav: waveform tensor; AudioDataset passes a single-channel ``(1, L)`` clip.
        n_frames: width of the time (last) axis of the result. Shorter
            spectrograms are right-padded, longer ones truncated. Default 80.
        transform: spectrogram callable; defaults to the module-level ``mel_spec``.

    Returns:
        ``log(spec + 1e-6)`` with its last axis padded/truncated to ``n_frames``.
    """
    import math

    eps = 1e-6
    spec_fn = mel_spec if transform is None else transform
    # eps keeps the log finite for silent (all-zero) bins.
    spec = torch.log(spec_fn(wav) + eps)
    n_time = spec.shape[-1]
    if n_time < n_frames:
        # Pad with log(eps), the value a silent bin maps to above. Padding with 0
        # would mean energy 1.0 in log space.
        spec = torch.nn.functional.pad(spec, (0, n_frames - n_time), value=math.log(eps))
    elif n_time > n_frames:
        # NOTE(review): with hop_length=200 at 16 kHz, 80 frames cover ~1 s, so
        # longer clips (config uses 5 s) keep only their first second -- confirm intended.
        spec = spec[..., :n_frames]
    return spec

class AudioDataset(Dataset):
    """Map-style dataset of (log-mel spectrogram, transcript) pairs.

    Rows come from a CSV with a ``path`` column (audio file relative to
    ``audio_dir``) and a ``sentence`` column (returned as the text).

    Args:
        csv_file: CSV path (or buffer) accepted by ``pd.read_csv``.
        audio_dir: directory that ``path`` entries are relative to.
        target_sr: sample rate every clip is resampled to.
        length_sec: if truthy, clips are zero-padded / trimmed to this many
            seconds; otherwise their natural length is kept.
        transform: optional callable applied to each mel spectrogram.
    """

    def __init__(self, csv_file, audio_dir, target_sr=16000, length_sec=None, transform=None):
        self.table = pd.read_csv(csv_file)
        self.audio_dir = Path(audio_dir)
        self.sr = target_sr
        self.transform = transform
        # Target clip length in samples; None disables pad_trim.
        self.length = None
        if length_sec:
            self.length = int(target_sr * length_sec)

    def load_wav(self, path):
        """Load ``path`` as a single-channel waveform at ``self.sr`` Hz."""
        signal, native_sr = torchaudio.load(path)
        # Down-mix multi-channel audio by averaging channels, keeping the channel axis.
        if signal.shape[0] > 1:
            signal = signal.mean(dim=0, keepdim=True)
        if native_sr != self.sr:
            signal = torchaudio.functional.resample(signal, native_sr, self.sr)
        return signal

    def pad_trim(self, wav):
        """Right-pad with zeros or truncate the last axis to ``self.length`` (no-op when None)."""
        if self.length is None:
            return wav
        missing = self.length - wav.shape[-1]
        if missing > 0:
            return torch.nn.functional.pad(wav, (0, missing))
        if missing < 0:
            return wav[..., : self.length]
        return wav

    def __getitem__(self, idx):
        """Return ``(mel, text)`` for row ``idx``."""
        record = self.table.iloc[idx]
        audio_path = self.audio_dir / record["path"]
        text = record["sentence"]
        waveform = self.pad_trim(self.load_wav(audio_path))
        mel = wav_to_mel(waveform.float())
        return (self.transform(mel) if self.transform else mel), text

    def __len__(self):
        return self.table.shape[0]


def audio_collate(batch):
    """Collate ``(mel, text)`` pairs into a stacked mel tensor and a list of texts.

    All mels must share one shape (wav_to_mel fixes the time axis); texts are
    passed through unchanged, in batch order.
    """
    mels = torch.stack([mel for mel, _ in batch])
    texts = [text for _, text in batch]
    return mels, texts


from torch.utils.data import DataLoader, SubsetRandomSampler

def build_dataloader(cfg, split, transform=None, workers=4, limit=10_000, seed=None):
    """Build a DataLoader over ``AudioDataset`` for one split of ``cfg``.

    Args:
        cfg: config dict; reads ``cfg["dataset"][split]`` ("table", "data"),
            ``cfg["vae"]`` ("freq", "lenght") and ``cfg[split]``
            ("batch_size", "pin_memory", optional "shuffle").
        split: split name, e.g. "train" or "val".
        transform: optional callable applied to each mel spectrogram.
        workers: number of DataLoader worker processes.
        limit: if truthy and smaller than the dataset, keep a random subset of
            this many distinct rows; ``None``/0 keeps the full dataset.
        seed: seed for the subset draw. ``None`` uses the global NumPy RNG, so
            ``np.random.seed(...)`` called elsewhere still controls the draw.

    Returns:
        A ``DataLoader`` yielding ``(mels, texts)`` batches via ``audio_collate``.
    """
    paths = cfg["dataset"][split]
    loader_cfg = cfg[split]
    ds = AudioDataset(
        paths["table"], paths["data"],
        cfg["vae"]["freq"], cfg["vae"]["lenght"],  # "lenght" is the config key's spelling
        transform,
    )
    if limit and limit < len(ds):
        if seed is None:
            idx = np.random.choice(len(ds), limit, replace=False)
        else:
            idx = np.random.default_rng(seed).choice(len(ds), limit, replace=False)
        ds = torch.utils.data.Subset(ds, idx.tolist())
    return DataLoader(
        ds,
        batch_size=loader_cfg["batch_size"],
        # Honour the split's "shuffle" flag; default to shuffling only "train".
        shuffle=loader_cfg.get("shuffle", split == "train"),
        num_workers=workers,
        pin_memory=loader_cfg["pin_memory"],
        collate_fn=audio_collate,
    )


## Модель


### Код

In [32]:
# Convolutional VAE that maps an object into a lower-dimensional latent space:
# the input goes in, a latent tensor comes out describing its features, which can
# be compared with other inputs and used for training downstream models.

class VAE_Audio(nn.Module):
    """Convolutional VAE for single-channel spectrogram "images".

    The encoder halves H and W twice (two stride-2 convs) into a 32-channel
    latent map; the decoder doubles them twice back to one channel. For a
    ``(B, 1, 80, 80)`` input the latent tensors are ``(B, 32, 20, 20)``.
    """

    def __init__(self,):
        super().__init__()
        # Feature extraction at full resolution (1 -> 64 -> 32 channels).
        self.encoder_input = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
        )
        # Two stride-2 convolutions: spatial size / 4.
        self.encoder_squeeze = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
        )
        # 1x1 heads parameterising the diagonal Gaussian posterior:
        # mu is the mean and logvar the log-variance of each latent element.
        self.encoder_mu = nn.Conv2d(32, 32, kernel_size=1)
        self.encoder_logvar = nn.Conv2d(32, 32, kernel_size=1)
        # Two stride-2 transposed convolutions: spatial size x 4.
        self.decoder_unsqueeze = nn.Sequential(
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
        )
        # Output head back to one channel. GELU precedes BatchNorm here (unlike the
        # blocks above); the order is kept so state_dict keys
        # (decoder_output.0 / .2 / .3) still match saved checkpoints.
        self.decoder_output = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=1),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for input ``x``."""
        features = self.encoder_squeeze(self.encoder_input(x))
        return self.encoder_mu(features), self.encoder_logvar(features)

    def decode(self, x):
        """Map a latent tensor back to input space."""
        return self.decoder_output(self.decoder_unsqueeze(x))

    def KLD_loss(self, mu, logvar, q=0.005):
        """Mean element-wise KL(N(mu, exp(logvar)) || N(0, 1)), floored at ``q``.

        Elements whose KL is already below ``q`` contribute the constant ``q``
        (zero gradient) -- similar in spirit to a "free bits" relaxation.
        """
        kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        return kld.clamp(min=q).mean()

    def forward(self, x):
        """Encode, sample ``z`` via the reparameterisation trick, and decode.

        Returns ``(recon, z, mu, logvar)``. ``z`` is sampled in both train and
        eval mode.
        """
        mu, logvar = self.encode(x)
        sigma = (0.5 * logvar).exp()
        z = mu + sigma * torch.randn_like(sigma)
        recon = self.decode(z)
        return recon, z, mu, logvar


In [37]:
from tqdm.notebook import tqdm
class AverageMeter:
    """Collects scalar values and reports their sample-weighted mean via ``str()``."""

    def __init__(self,):
        self.zero()

    def __call__(self, item, n=1):
        # Store `item` once per sample so the mean is weighted by batch size;
        # any n <= 1 (including 0) stores it exactly once.
        copies = 1 if n <= 1 else n
        self.arr.extend([item] * copies)

    def __str__(self,) -> str:
        return str(np.asarray(self.arr).mean())

    def zero(self,):
        """Forget all recorded values."""
        self.arr = []

class VAE_Trainer:
    """Training/validation driver for a VAE whose forward returns ``(recon, z, mu, logvar)``.

    The training loss is L1 reconstruction plus ``beta * model.KLD_loss(mu, logvar)``.
    ``beta`` warms up linearly with the epoch index to a cap of 0.5 at epoch 20.

    Args:
        model: the VAE; must expose ``KLD_loss(mu, logvar)``.
        train_dataloader: iterable of ``(mel, text)`` training batches.
        val_dataloader: iterable of ``(mel, text)`` validation batches.
        lr: AdamW learning rate (default 3e-4).
        weight_decay: AdamW weight decay (default 1e-3).
    """

    def __init__(self, model, train_dataloader, val_dataloader, lr=3e-4, weight_decay=1e-3):
        self.model = model
        # Batches are moved to whichever device the model's parameters live on.
        self.device = next(model.parameters()).device
        self.tdl = train_dataloader
        self.vdl = val_dataloader
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        self.rec_loss = nn.L1Loss(reduction="mean")
        # One meter shared by both loops; each loop resets it on entry.
        self.loss_meter = AverageMeter()

    @staticmethod
    def _show_example(model, loader, epoch, device):
        """Plot the first item of ``loader``'s first batch next to its reconstruction."""
        mel, _ = next(iter(loader))
        mel = mel.to(device)
        with torch.no_grad():
            recon, _, _, _ = model(mel[:1])
        fig, ax = plt.subplots(1, 2, figsize=(6, 2.5))
        ax[0].imshow(mel[0, 0].cpu(), origin='lower', aspect='auto', cmap="magma")
        ax[0].set_title('orig')
        ax[1].imshow(recon[0, 0].cpu(), origin='lower', aspect='auto', cmap="magma")
        ax[1].set_title(f'recon e{epoch}')
        for a in ax:
            a.axis('off')
        plt.show()
        plt.close(fig)

    def train_loop(self, epoch=0):
        """Run one training epoch and print the sample-weighted mean total loss."""
        self.model.train()
        self.loss_meter.zero()
        # Linear KL warm-up: 0 at epoch 0, reaching its 0.5 cap at epoch 20.
        beta = min(1.0, epoch / 20) * 0.5
        pbar = tqdm(self.tdl, desc=f'train e{epoch}', leave=False)
        for original_audio, _ in pbar:
            original_audio = original_audio.to(self.device, non_blocking=True)
            output, _, mu, logvar = self.model(original_audio)
            recon = self.rec_loss(output, original_audio)
            KLD = self.model.KLD_loss(mu, logvar)
            loss = recon + beta * KLD
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.loss_meter(loss.item(), original_audio.size(0))
            pbar.set_postfix(recon=f'{recon.item():.4f}', beta=f'{beta:.4f}', kld=f'{(beta*KLD).item():.6f}')
        print("Train Loss = " + str(self.loss_meter))

    def val_loop(self, epoch=None):
        """Evaluate on the validation set, print latent statistics, and plot one example.

        The reported loss is reconstruction only (the train loss also includes
        ``beta * KLD``). NOTE(review): VAE_Audio.forward samples z even in eval
        mode, so this loss varies from run to run.
        """
        self.model.eval()
        self.loss_meter.zero()
        mu_means, mu_stds = [], []
        lv_means, lv_stds = [], []
        with torch.no_grad():
            for original_audio, _ in tqdm(self.vdl, desc='val'):
                original_audio = original_audio.to(self.device, non_blocking=True)
                output, _, mu, logvar = self.model(original_audio)
                mu_means.append(mu.mean().item())
                mu_stds.append(mu.std().item())
                lv_means.append(logvar.mean().item())
                lv_stds.append(logvar.std().item())
                loss = self.rec_loss(output, original_audio)
                self.loss_meter(loss.item(), original_audio.size(0))
        if not mu_means:
            # An empty loader would otherwise divide by zero below and make
            # _show_example raise StopIteration.
            print("Validation skipped: validation dataloader produced no batches.")
            return
        print(f"Validation loss = {self.loss_meter}")
        print(f"Encoder μ:   mean = {sum(mu_means)/len(mu_means)}, std = {sum(mu_stds)/len(mu_stds)}")
        print(f"Encoder logσ²: mean = {sum(lv_means)/len(lv_means)}, std = {sum(lv_stds)/len(lv_stds)}")
        self._show_example(self.model, self.vdl, epoch, self.device)


In [38]:
# (cudnn.benchmark is already enabled in the imports cell.)
# Seed NumPy (random train subset drawn in build_dataloader) and PyTorch (weight
# init, DataLoader shuffling, reparameterisation noise) so runs are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.set_float32_matmul_precision("high")
# NOTE(review): workers=0 presumably avoids multiprocessing problems with classes
# defined inside the notebook (data lives on E:/, i.e. Windows) -- confirm.
train_dataloader = build_dataloader(config, "train", workers=0, limit=1000)
val_dataloader = build_dataloader(config, "val", workers=0, limit=None)
vae = VAE_Audio().to(device)
trainer = VAE_Trainer(vae, train_dataloader, val_dataloader)

  self.table = pd.read_csv(csv_file)


In [39]:
%%time
mel, _ = next(iter(train_dataloader))
print("batch shape:", mel.shape)


batch shape: torch.Size([32, 1, 80, 80])
CPU times: total: 1.17 s
Wall time: 174 ms


In [None]:
# NOTE(review): config["model"]["epochs"] is 15 -- decide which value is authoritative.
EPOCHS = 30
SAVE_EVERY = 5
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    trainer.train_loop(epoch)
    trainer.val_loop(epoch)
    # Checkpoint every SAVE_EVERY epochs and always after the final epoch, so the
    # last weights are kept even when EPOCHS is not a multiple of SAVE_EVERY.
    if (epoch + 1) % SAVE_EVERY == 0 or epoch + 1 == EPOCHS:
        torch.save(vae.state_dict(), f"TTSVAE_v10.{epoch + 1}.pt")


Epoch 1/30


train e0:   0%|          | 0/32 [00:00<?, ?it/s]

Train Loss = 7.561288307189941


val:   0%|          | 0/532 [00:00<?, ?it/s]