In [87]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample

device = 'mps' if torch.backends.mps.is_available() else 'cpu'

## FÖR ATT ÅTERSKAPA: test_output = Resample(resample_rate,48000)(next(iter(dataset)))
## FÖR ATT SPARA: torchaudio.save("./test_output.wav",test_output.unsqueeze(0),48000)


class MP3Dataset(Dataset):
    def __init__(self, file_paths, resample_rate):
        self.file_paths = file_paths
        self.resample_rate = resample_rate
    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        waveform, sample_rate = librosa.load(self.file_paths[idx],sr=48000)
        waveform = torch.Tensor(waveform)
        waveform = Resample(sample_rate, self.resample_rate)(waveform)
        waveform = waveform.to(device)
        return waveform


# Create a DataLoader for your dataset
# You need to replace 'file_paths' with the actual paths to your MP3 files
data_dir = "./output_chunks/"
file_paths = [f for f in os.listdir(data_dir) if f.endswith(".mp3")]
file_paths = [os.path.join(data_dir, x) for x in file_paths]

resample_rate = 12000
dataset = MP3Dataset(file_paths,resample_rate)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [103]:
import torch
import librosa
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super(VAE, self).__init__()

        # Encoder
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)
        self.fc22 = nn.Linear(hidden_size, latent_size)

        # Decoder
        self.fc3 = nn.Linear(latent_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        z = mu + eps*std
        return z

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Reconstruction + KL divergence losses summed over all elements and batch
class loss_function(nn.Module):
    def __init__(self):
        super(loss_function,self).__init__()

    def forward(self, recon_x, x, mu, logvar):
        MSE = F.mse_loss(recon_x,x)

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = 0.12 * torch.sum(1 + logvar - mu**2 - torch.exp(logvar))

        return MSE - KLD

# Initialize the VAE
input_size = 180000 # MNIST images are 28x28 = 784
hidden_size = 400
latent_size = 20
vae = VAE(input_size, hidden_size, latent_size)
vae.to(device)

# Define the optimizer
optimizer = optim.Adam(vae.parameters(), lr=1e-5)
VAEloss = loss_function()


In [107]:
## Training Loop

vae.train()
n_epochs = 100
for epoch in range(n_epochs):
    batch = next(iter(dataloader))
    decoded, mu, logvar = vae(batch)

    loss = VAEloss(decoded,batch,mu,logvar)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print("loss: ", loss)

loss:  tensor(0.3183, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3455, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.4098, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3315, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.5399, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.4813, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3599, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.5283, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3793, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3316, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3239, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3260, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3665, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3762, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3727, device='mps:0', grad_fn=<SubBackward0>)
loss:  tensor(0.3531, device='mps:0', grad_fn=<SubBackward0>)
loss:  t

KeyboardInterrupt: 

<enumerate at 0x294f86c00>

In [111]:
test_output = (next(iter(dataset)))
test_output = vae(test_output)[0].to('cpu')
test_output = Resample(resample_rate,48000)(test_output)
torchaudio.save("./test_output.wav",test_output.unsqueeze(0),48000)