In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample

device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print(device)

## FÖR ATT ÅTERSKAPA: test_output = Resample(resample_rate,48000)(next(iter(dataset)))
## FÖR ATT SPARA: torchaudio.save("./test_output.wav",test_output.unsqueeze(0),48000)


class MP3Dataset(Dataset):
    def __init__(self, file_paths, resample_rate):
        self.file_paths = file_paths
        self.resample_rate = resample_rate
    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        waveform, sample_rate = librosa.load(self.file_paths[idx],sr=48000)
        waveform = torch.Tensor(waveform).unsqueeze(0)
        waveform = Resample(sample_rate, self.resample_rate)(waveform)
        waveform = waveform.to(device)
        return waveform


# Create a DataLoader for your dataset
# You need to replace 'file_paths' with the actual paths to your MP3 files
data_dir = "./output_chunks/"
file_paths = [f for f in os.listdir(data_dir) if f.endswith(".mp3")]
file_paths = [os.path.join(data_dir, x) for x in file_paths]

resample_rate = 12000
dataset = MP3Dataset(file_paths,resample_rate)

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

mps


In [15]:
import torch
import librosa
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summary

class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super(VAE, self).__init__()

        channels = 8
        self.encoder = nn.Sequential(
            nn.Conv1d(1, channels, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm1d(channels),
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=3, stride=4, padding=0),
            nn.BatchNorm1d(channels),
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=3, stride=4, padding=1),
            nn.BatchNorm1d(channels),
            nn.ReLU(),
        )

        
        self.flatten = nn.Flatten(1)
        self.enc_mu = nn.Linear(11250*channels, 1000)
        self.enc_logvar = nn.Linear(11250*channels, 1000)
        

        self.decoder = nn.Sequential(
            nn.Linear(1000,11250*channels),
            nn.Unflatten(1,(channels,11250)),
            nn.ConvTranspose1d(channels, channels, kernel_size=3, stride=4),
            nn.BatchNorm1d(channels),
            nn.ReLU(),
            nn.ConvTranspose1d(channels, channels, kernel_size=3, stride=4, padding=0,output_padding=3),
            nn.ReLU(),
            nn.ConvTranspose1d(channels, 1, kernel_size=3, stride=1, padding=0),
            nn.ReLU()
        )
        # Decoder
        #self.fc3 = nn.Linear(latent_size, hidden_size)
        #self.fc4 = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        encoder_output = self.encoder(x)
        encoder_output = self.flatten(encoder_output)
        return self.enc_mu(encoder_output), self.enc_logvar(encoder_output)


    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        z = mu + eps*std
        return z

    def decode(self, z):
        decoder_output = self.decoder(z)
        return decoder_output

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Reconstruction + KL divergence losses summed over all elements and batch
class loss_function(nn.Module):
    def __init__(self):
        super(loss_function,self).__init__()

    def forward(self, recon_x, x, mu, logvar):
        MSE = F.mse_loss(recon_x,x)
        
        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

        #KLD = - 0.5 * torch.sum(1 + logvar - mu**2 - torch.exp(logvar))
        #print(KLD.shape)

        KLD = - 0.5 * torch.sum(1 + logvar - (torch.pow(mu, 2) + torch.exp(logvar)))
        #print(KLD.shape)
        return MSE + KLD

# Initialize the VAE
input_size = 180000 # MNIST images are 28x28 = 784
hidden_size = 400
latent_size = 20


vae = VAE(input_size, hidden_size, latent_size)
vae.to(device)

# Define the optimizer
optimizer = optim.Adam(vae.parameters(), lr=1e-5) # 1e-5 seems good
VAEloss = loss_function()


#input_size = resample_rate*15 # 15s per sample

summary(vae,input_size=[64,1,input_size])

Layer (type:depth-idx)                   Output Shape              Param #
VAE                                      [64, 1, 180000]           --
├─Sequential: 1-1                        [64, 8, 11250]            --
│    └─Conv1d: 2-1                       [64, 8, 179998]           32
│    └─BatchNorm1d: 2-2                  [64, 8, 179998]           16
│    └─ReLU: 2-3                         [64, 8, 179998]           --
│    └─Conv1d: 2-4                       [64, 8, 44999]            200
│    └─BatchNorm1d: 2-5                  [64, 8, 44999]            16
│    └─ReLU: 2-6                         [64, 8, 44999]            --
│    └─Conv1d: 2-7                       [64, 8, 11250]            200
│    └─BatchNorm1d: 2-8                  [64, 8, 11250]            16
│    └─ReLU: 2-9                         [64, 8, 11250]            --
├─Flatten: 1-2                           [64, 90000]               --
├─Linear: 1-3                            [64, 1000]                90,001,000
├─Lin

In [16]:
## Training Loop
print(device)
vae.train()

print(next(vae.parameters()).device)
vae.to(device)
print(next(vae.parameters()).device)
n_epochs = 1000
for epoch in range(n_epochs):
    batch = next(iter(dataloader))

    decoded, mu, logvar = vae(batch)

    loss = VAEloss(decoded,batch,mu,logvar)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print("loss: ", loss)

mps
cpu
mps:0
loss:  tensor(6642.7026, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(5328.0234, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(5349.5464, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(5011.0840, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(4550.4814, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(4425.6494, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(4286.2031, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(4170.0752, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(3656.0632, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(3782.5752, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(3570.0342, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(3444.2532, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(3513.3032, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(3098.4033, device='mps:0', grad_fn=<AddBackward0>)
loss:  tensor(3080.0901, device='mps:0', grad_fn=<AddBackward0>)
loss:  tens

KeyboardInterrupt: 

In [20]:
vae.eval()

VAE(
  (encoder): Sequential(
    (0): Conv1d(1, 4, kernel_size=(3,), stride=(1,))
    (1): ReLU()
    (2): Conv1d(4, 4, kernel_size=(3,), stride=(4,))
    (3): ReLU()
    (4): Conv1d(4, 4, kernel_size=(3,), stride=(4,), padding=(1,))
    (5): ReLU()
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (enc_mu): Linear(in_features=45000, out_features=1000, bias=True)
  (enc_logvar): Linear(in_features=45000, out_features=1000, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=1000, out_features=45000, bias=True)
    (1): Unflatten(dim=1, unflattened_size=(4, 11250))
    (2): ConvTranspose1d(4, 4, kernel_size=(3,), stride=(4,))
    (3): ReLU()
    (4): ConvTranspose1d(4, 4, kernel_size=(3,), stride=(4,), output_padding=(3,))
    (5): ReLU()
    (6): ConvTranspose1d(4, 1, kernel_size=(3,), stride=(1,))
    (7): ReLU()
  )
)

In [21]:
torch.save(vae.state_dict(), './models/model_002.pth')

In [17]:
vae.eval()
test_output = (next(iter(dataset)))
print(test_output.shape)

torch.Size([1, 180000])


In [18]:
model = vae.eval().to('cpu')

test_output = (next(iter(dataset)))
test_output = test_output.unsqueeze(0).to('cpu')
test_output = model(test_output)[0]
test_output = test_output.squeeze(0)
print(test_output.shape)
test_output = Resample(resample_rate,48000)(test_output)
torchaudio.save("./test_output.wav",test_output,48000)


torch.Size([1, 180000])
