In [2]:
import torch
from torch import nn
from torch.nn import functional as F
import pandas as pd
import numpy as np
import torchaudio
class VAE_Audio(nn.Module):
    def __init__(self,):
        super().__init__()
        self.encoder_input = nn.Sequential(*[
                nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1),
                nn.LazyBatchNorm2d(),
                nn.GELU(),
                nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, stride=1, padding=1),
            ])
        self.encoder_squeeze = nn.Sequential(*[
                nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=2, padding=1), # TODO перепроверить размеры
                nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, stride=2, padding=1), # TODO перепроверить размеры
            ])
        self.encoder_mu = nn.Conv2d(1, 1, 1)
        self.encoder_logvar = nn.Conv2d(1, 1, 1)
        self.decoder_unsqueeze = nn.Sequential(*[
                nn.ConvTranspose2d(in_channels=1, out_channels=2, kernel_size=3, stride=2, padding=1, output_padding=1), # TODO перепроверить размеры
                nn.ConvTranspose2d(in_channels=2, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1), # TODO перепроверить размеры
            ])
        self.decoder_output = nn.Sequential(*[
                nn.ConvTranspose2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1),
                nn.GELU(),
                nn.LazyBatchNorm2d(),
                nn.ConvTranspose2d(in_channels=2, out_channels=1, kernel_size=3, stride=1, padding=1),
            ])
    def encode(self, x):
        x = self.encoder_input(x)
        x = self.encoder_squeeze(x)
        mu = self.encoder_mu(x)
        logvar = self.encoder_logvar(x)
        return mu, logvar
    def sample(self, x):
        mu, logvar = self.encode(x)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + logvar*eps
        return z, mu, logvar
    def decode(self, x):
        x = self.decoder_unsqueeze(x)
        x = self.decoder_output(x)
        return x
    def KLD_loss(self, mu, logvar):
        kld_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0)
        return kld_loss
    def forward(self, x):
        z, mu, logvar = self.sample(x)
        return self.decode(z), z, mu, logvar

In [3]:
from tqdm.auto import tqdm
class AvegereMeter:
    def __init__(self,):
        self.arr = []
    def __call__(self, item, n=1):
        if n<=1:
            self.arr.extend([item])
        else:
            self.arr.extend([item]*n)
    def __str__(self,) -> str:
        return str(np.mean(np.array(self.arr)))
    def zero(self,):
        self.arr=[]

class VAE_Trainer:
    def __init__(self, model, train_dataloader, val_dataloader,):
        self.model = model
        self.tdl = train_dataloader
        self.vdl = val_dataloader
        self.optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
        self.rec_loss = nn.MSELoss()
        self.loss_meter = AvegereMeter()
        # self.scheduler

    def train_loop(self, k=0.01):
        self.model.train()
        self.loss_meter.zero()
        for batch in tqdm(self.tdl):
            audio = batch[::, 0]
            z, mu, logvar = self.model.sample(audio)
            output = self.model.decode(z)
            loss = self.rec_loss(output, audio)+k*self.model.KLD_loss(mu, logvar)
            loss.backward()
            self.loss_meter(loss.item(), batch.shape[0])
            self.optimizer.step()
            self.optimizer.zero_grad()
        #self.scheduler.step()
        print("Loss = "+self.loss_meter.__str__())
    def val_loop(self):
        self.model.eval()
        self.loss_meter.zero()
        for batch in tqdm(self.tdl):
            with torch.no_grad():
                audio = batch[::, 0]
                z, mu, logvar = self.model.sample(audio)
                output = self.model.decode(mu)
                loss = self.rec_loss(output, audio)
                self.loss_meter(loss.item(), batch.shape[0])
        print("Val loss = "+self.loss_meter.__str__())


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, transform=transform)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

vae = VAE_Audio()

trainer = VAE_Trainer(vae, train_dataloader, val_dataloader)
for epoch in tqdm(range(100)):
    trainer.train_loop(0.01)
    trainer.val_loop()
torch.save(vae.state_dict(), "vae.pt")


100%|██████████| 9.91M/9.91M [00:00<00:00, 20.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 615kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 5.74MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.95MB/s]
  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s]


TypeError: list indices must be integers or slices, not tuple