In [14]:
import torch
import torchaudio
from torch.utils.data import Dataset

import numpy as np
import os # for file path manipulation

import csv # for reading tsv files


# custom dataset class
class SpeechDataset(Dataset):
    def __init__(self, tsvs=[], sample_rate=16000, transform=None, columns=['path']):
        self.tsvs = tsvs
        self.sample_rate = sample_rate
        self.transform = transform
        self.columns = columns
        self.data = []

        # load metadata
        self._load_metadata()


    def split(self, split_ratio=0.8):
        # split data into train and test sets
        # split_ratio is the ratio of training data to test data
        # returns two SpeechDataset objects, one for train and one for test

        # get split index
        split_idx = int(len(self.data) * split_ratio)

        # split data
        train_data = self.data[:split_idx]
        test_data = self.data[split_idx:]

        # create new SpeechDataset objects
        train_dataset = SpeechDataset(sample_rate=self.sample_rate, transform=self.transform, columns=self.columns)
        test_dataset = SpeechDataset(sample_rate=self.sample_rate, transform=self.transform, columns=self.columns)

        # set data
        train_dataset.data = train_data
        test_dataset.data = test_data

        return train_dataset, test_dataset

    def _load_metadata(self):
        self.data = []
        for tsv in self.tsvs:
            dir_path, _ = os.path.split(tsv)
            
            clips = os.path.join(dir_path, 'clips', '')
            
            # read tsv and append to data
            with open(tsv, 'r') as f:
                reader = csv.DictReader(f, delimiter='\t')
                for row in reader:
                    # commonvoice columns:
                    # client_id	path	sentence	up_votes	down_votes	age	gender	accents	variant	locale	segment
                    
                    # get columns
                    data = [row[col] for col in self.columns]
                    if 'path' in self.columns:
                        # convert path to absolute path
                        path_idx = self.columns.index('path')
                        data[path_idx] = clips + data[path_idx]
                    # append to data
                    self.data.append(data)


        # shuffle data
        np.random.shuffle(self.data)

    def get_column_names(self):
        # if path is included, last column is audio data that will be loaded in __getitem__
        if 'path' in self.columns:
            # self.columns + ['audio']
            return self.columns + ['audio']
        else:
            return self.columns
        

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):

        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        # load data
        sample = self.data[idx]
        # load audio (if path is in sample)
        if 'path' in self.columns:
            # load audio
            #print(sample[self.columns.index('path')])
            audio, sample_rate = torchaudio.load(sample[self.columns.index('path')])
            
            # resample audio if necessary
            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(sample_rate, self.sample_rate)
                audio = resampler(audio)

                # normalize audio
                audio = audio / torch.max(torch.abs(audio))

            
            # add audio to sample
            sample.append(audio)

        # apply transform if necessary
        if self.transform:
            sample = self.transform(sample)

        return sample


In [15]:
dataset = SpeechDataset(tsvs=[
    'commonvoice\\cv-corpus-16.0-delta-2023-12-06\\en\\validated.tsv',
    'commonvoice\\cv-corpus-16.0-delta-2023-12-06\\de\\validated.tsv', 
    'commonvoice\\cv-corpus-16.0-delta-2023-12-06\\ja\\validated.tsv'], columns=['path', 'sentence'])

print('Dataset length:', len(dataset))
print('Dataset columns:', dataset.get_column_names())

import random
# get first sample
sample = dataset[random.randint(0, len(dataset))]
print('Sample:', sample)

# get audio from sample
audio = sample[-1]

# play audio
from IPython.display import Audio


# Play the audio using IPython's Audio widget
audio_widget = Audio(data=audio.numpy(), rate=16000)
display(audio_widget)

Dataset length: 16894
Dataset columns: ['path', 'sentence', 'audio']
Sample: ['commonvoice\\cv-corpus-16.0-delta-2023-12-06\\de\\clips\\common_voice_de_38561839.mp3', 'Die Stadt bezieht ihr Leitungswasser aus einem großen See.', tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.1659e-07,
         -1.3548e-07, -1.3411e-07]])]


In [3]:

# baseline models
import torch.nn as nn

SAMPLE_RATE = 16000

def pad_batch(batch):
    if isinstance(batch[0], list):
        # if batch is list of list, get tensor from last element
        batch = [sample[-1].reshape(-1) for sample in batch]
    # pads batch to longest sequence
    # batch is list of samples
    lengths = [len(sample) for sample in batch]
    max_length = max(lengths)
    max_length = (max_length // (SAMPLE_RATE // 100) + 1) * (SAMPLE_RATE // 100)
    # pad to max length
    padded_batch = [torch.nn.functional.pad(sample, (0, max_length - len(sample))) for sample in batch]
    return torch.stack(padded_batch)

class BaselineEmbedder(nn.Module):
    def __init__(self, sample_rate = SAMPLE_RATE, embedding_dim=32):
        super(BaselineEmbedder, self).__init__()
        self.sample_rate = sample_rate
        self.embedding_dim = embedding_dim

        # lstm layers
        self.lstm = nn.LSTM(input_size=1, hidden_size=embedding_dim, num_layers=3, batch_first=True)

    
    def forward(self, x):
        # x is audio, clips are padded to longest sequence
        # x is (batch_size, samples)

        # reshape to (batch_size, samples, 1)
        x = x.unsqueeze(2)
        x = self.lstm(x)
        # get last hidden state
        x = x[0][:, -1, :]
        x = x.reshape(-1, self.embedding_dim)
        return x
    


In [4]:
baseline = BaselineEmbedder()
print(baseline)

batch = [dataset[random.randint(0, len(dataset))][-1] for _ in range(16)]
batch = [sample[-1] for sample in batch]
batch = pad_batch(batch)

print('Input shape:', batch.shape)

# get embeddings
embeddings = baseline(batch)
print('Embeddings shape:', embeddings.shape)


BaselineEmbedder(
  (lstm): LSTM(1, 32, num_layers=3, batch_first=True)
)
Input shape: torch.Size([16, 163040])
Embeddings shape: torch.Size([16, 32])


In [25]:
# VAE and decoder

class print_shape(nn.Module):
    def __init__(self, message):
        super().__init__()
        self.message = message
    
    def forward(self, x):
        print(self.message, x.shape)
        return x

class VAE(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.training = True
        # encoder
        # input: audio waveform 16000 samples
        # latent space: 32 dimensions, 100 samples
        # goal: learn latent space representation of audio that is easier to use in RNNs

        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=20, stride=10, padding=5),
            # samples 16,000 -> 1,600
            nn.ReLU(),
            #nn.Linear(in_features=16, out_features=16),
            #nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=8, stride=4, padding=2),
            # samples 1,600 -> 400
            nn.ReLU(),
            #nn.Linear(in_features=32, out_features=32),
            #nn.ReLU(),
            nn.Conv1d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
            # samples 400 -> 200
            nn.ReLU(),
            #nn.Linear(in_features=32, out_features=32),
            #nn.ReLU(),
            nn.Conv1d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            # samples 200 -> 100
            #nn.ReLU(),
            #nn.Linear(in_features=64, out_features=64),
            nn.Tanh()
        )

        # decoder
        # input: latent space representation
        # output: audio waveform 16000 samples
        # goal: reconstruct original audio waveform from latent space representation

        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
            # samples 100 -> 200
            nn.ReLU(),
            #nn.Linear(in_features=32, out_features=32),
            #nn.ReLU(),
            nn.ConvTranspose1d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
            # samples 200 -> 400
            nn.ReLU(),
            #nn.Linear(in_features=32, out_features=32),
            #nn.ReLU(),
            nn.ConvTranspose1d(in_channels=32, out_channels=16, kernel_size=8, stride=4, padding=2),
            # samples 400 -> 1600
            nn.ReLU(),
            #nn.Linear(in_features=16, out_features=16),
            #nn.ReLU(),
            nn.ConvTranspose1d(in_channels=16, out_channels=1, kernel_size=20, stride=10, padding=5),
            # samples 1600 -> 16000
            nn.Tanh()
        )

    def set_training(self, training):
        self.training = training

    def sample(self, mu, log_var):
        # if not self.training:
        #     return mu
        if not self.training:
            return mu
        # reparameterization trick
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):

        # reshape to (n, 1, length)
        x = x.unsqueeze(1)
        # encode
        x = self.encoder(x)
        # get mu and log_var
        mu = x[:, :32]
        log_var = x[:, 32:]
        # sample from latent space
        z = self.sample(mu, log_var)
        # decode
        x = self.decoder(z)
        if not self.training:
            return x
        
        
        return x, mu, log_var

In [6]:
vae = VAE()
print(vae)

# get batch
batch = [dataset[random.randint(0, len(dataset))][-1] for _ in range(16)]
batch = [sample[-1] for sample in batch]
batch = pad_batch(batch)


print('Input shape:', batch.shape)

# get output
output = vae(batch)
print('Output shape:', output.shape)



VAE(
  (encoder): Sequential(
    (0): Conv1d(1, 16, kernel_size=(20,), stride=(10,), padding=(5,))
    (1): ReLU()
    (2): Conv1d(16, 32, kernel_size=(8,), stride=(4,), padding=(2,))
    (3): ReLU()
    (4): Conv1d(32, 32, kernel_size=(4,), stride=(2,), padding=(1,))
    (5): ReLU()
    (6): Conv1d(32, 64, kernel_size=(4,), stride=(2,), padding=(1,))
    (7): ReLU()
  )
  (decoder): Sequential(
    (0): ConvTranspose1d(32, 32, kernel_size=(4,), stride=(2,), padding=(1,))
    (1): ReLU()
    (2): ConvTranspose1d(32, 32, kernel_size=(4,), stride=(2,), padding=(1,))
    (3): ReLU()
    (4): ConvTranspose1d(32, 16, kernel_size=(8,), stride=(4,), padding=(2,))
    (5): ReLU()
    (6): ConvTranspose1d(16, 1, kernel_size=(20,), stride=(10,), padding=(5,))
    (7): ReLU()
  )
)
Input shape: torch.Size([16, 153920])
Output shape: torch.Size([16, 1, 153920])


In [30]:
# train vae
import torch.optim as optim
from torch.utils.data import DataLoader
import tqdm

# hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 0.0001
EPOCHS = 3
kl_beta = 0.1

# create dataloader
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_batch)

# create model
vae = VAE()
vae.train()
# create optimizer
optimizer = optim.Adam(vae.parameters(), lr=LEARNING_RATE)

# train

def print_progress(epoch, batch, loss):
    prog = batch / len(dataloader)
    prog = int(prog * 20)
    print(f'Epoch: {epoch} | {"#" * prog}{"-" * (20 - prog)} | Loss: {loss.item()}', end='\r')

for epoch in range(EPOCHS):
    for i, batch in enumerate(dataloader):
        # zero gradients
        optimizer.zero_grad()

        # forward pass
        output, mu, var = vae(batch)
        output = output.squeeze(1)

        # calculate loss
        reconstruction_loss = torch.nn.functional.mse_loss(output, batch)
        # KL divergence
        kl_divergence = -0.5 * torch.sum(1 + var - mu.pow(2) - var.exp())
        # total loss
        loss = reconstruction_loss

        # backward pass
        loss.backward()

        # update weights
        optimizer.step()

        # print progress
        print_progress(epoch, i, loss)
    print()

# save model
torch.save(vae.state_dict(), 'weights/vae.pth')

Epoch: 0
Epoch: 1 | ###################- | Loss: 0.32380279898643494
Epoch: 2 | ###################- | Loss: 0.17883901298046112
Epoch: 2 | ###################- | Loss: 0.10987542569637299

In [36]:
#vae = VAE()
vae.set_training(False)
#vae.load_state_dict(torch.load('weights/vae.pth'))

vae.eval()

# test vae
# get sample
sample = dataset[random.randint(0, len(dataset))]
audio = sample[-1]
print(audio[:100])
audio.reshape(1, -1)

# play before
audio_widget = Audio(data=audio.numpy()[0], rate=16000)
display(audio_widget)

# get reconstruction
output = vae(audio)

# play after
output_np = output.detach().numpy()[0]
output_np = output_np.reshape(-1)
print(output_np[:100])
audio_widget = Audio(data=output.detach().numpy()[0], rate=16000)
display(audio_widget)


tensor([[-1.0927e-12, -3.7364e-13, -4.0772e-12,  ..., -5.8550e-06,
          1.6587e-06,  6.4338e-06]])


[-0.15569158 -0.11841872 -0.13451362 -0.1494734  -0.10584262 -0.15234436
 -0.1516264  -0.14396578 -0.14900397 -0.18914557 -0.12512892 -0.10927171
 -0.11981604 -0.16183333 -0.12885489 -0.08925357 -0.07815778 -0.14749113
 -0.12288178 -0.14614207 -0.13512006 -0.07782374 -0.11443429 -0.1052385
 -0.08533059 -0.10006317 -0.10752019 -0.10135372 -0.08168423 -0.1262272
 -0.06532541 -0.08418865 -0.07823425 -0.10905566 -0.07291453 -0.09705845
 -0.10420202 -0.09458698 -0.09117529 -0.1096366  -0.08265574 -0.07489337
 -0.10463782 -0.10716467 -0.07434531 -0.08474568 -0.10143303 -0.0851464
 -0.08829395 -0.11799605 -0.07201254 -0.08722376 -0.08737589 -0.1104067
 -0.07953914 -0.07399999 -0.08102004 -0.09139003 -0.08332301 -0.10846446
 -0.08763192 -0.06878931 -0.08473125 -0.0948327  -0.06215488 -0.08479345
 -0.08899587 -0.08438909 -0.07098583 -0.10294123 -0.05105901 -0.06369541
 -0.06330512 -0.08339896 -0.06030434 -0.07873052 -0.08817818 -0.07353391
 -0.07026399 -0.08809669 -0.06110768 -0.06332246 -0.086