# Everything Data

In [103]:
### IMPORTS ###
import os
import requests
import zipfile
import shutil
import os
from tqdm import tqdm
import librosa
import soundfile as sf
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import sentencepiece
import gc
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T

In [95]:
### CONSTANTS

SR = int(8e3)
SRkHz = int(SR//1e3)
VOCAB_SIZE = int(4e3)
MAX_CLIP_SECS = 2
MAX_WAV_LEN = int(SR*MAX_CLIP_SECS)

PAD_ID = 0
BOS_ID = 1
EOS_ID = 2
UNK_ID = 3

AUDIO_PAD_ID = -2.0

## Download and process data

In [78]:
def download_vctk():  
    # Define the URL and the target paths
    url = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip'
    data_dir = './data/VCTK/raw'
    download_path = os.path.join(data_dir, 'VCTK-Corpus-0.92.zip')
    extract_path = os.path.join(data_dir, 'VCTK')

    # Ensure the data directory exists
    os.makedirs(data_dir, exist_ok=True)

    # Download the dataset
    print(f"Downloading VCTK dataset from {url}...")
    response = requests.get(url, stream=True)
    with open(download_path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    print("Download complete.")

    # Unzip the file
    print(f"Extracting {download_path} to {data_dir}...")
    with zipfile.ZipFile(download_path, 'r') as zip_ref:
        zip_ref.extractall(data_dir)
    print("Extraction complete.")

    # Find the extracted folder and rename it to "VCTK"
    extracted_folder_name = 'VCTK-Corpus-0.92'
    original_extract_path = os.path.join(data_dir, extracted_folder_name)

    if os.path.exists(original_extract_path):
        os.rename(original_extract_path, extract_path)
        print(f"Renamed {original_extract_path} to {extract_path}")
    else:
        print(f"Expected extracted folder {original_extract_path} not found")

    print(f"VCTK dataset is ready at {extract_path}")

In [79]:
def process_data(target_sample_rate):
    # Define paths and target sample rate
    input_dir = './data/VCTK/raw/wav48_silence_trimmed'
    output_dir = './data/VCTK/raw/wav{}'.format(int(target_sample_rate // 1e3))  
    
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Collect all files to process
    files_to_process = []
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            if file.endswith("_mic1.flac"):
                files_to_process.append((root, file))

    # Process files with a progress bar
    for root, file in tqdm(files_to_process, desc="Processing files", unit="file"):
        # Construct full file path
        file_path = os.path.join(root, file)

        # Load the audio file using librosa
        audio, sr = librosa.load(file_path, sr=None)

        # Downsample the audio file to the target sample rate
        audio_resampled = librosa.resample(audio, orig_sr=sr, target_sr=target_sample_rate)

        # Remove '_mic1' from the file name and change extension to .wav
        new_file_name = file.replace('_mic1.flac', '.wav')

        # Construct the output file path
        relative_path = os.path.relpath(file_path, input_dir)
        relative_dir = os.path.dirname(relative_path)
        output_file_path = os.path.join(output_dir, relative_dir, new_file_name)
        output_file_dir = os.path.dirname(output_file_path)
        os.makedirs(output_file_dir, exist_ok=True)

        # Export the downsampled audio file as a .wav file using soundfile
        sf.write(output_file_path, audio_resampled, target_sample_rate)

In [80]:
if not os.path.exists("./data/VCTK/raw/wav48_silence_trimmed"):
    download_vctk()

In [81]:
# Process the audio files

if not os.path.exists("./data/VCTK/raw/wav{}".format(SRkHz)):
    process_data(SR)

## Make Dataset

In [82]:
def read_speaker_info():
    speaker_info_path = './data/VCTK/raw/speaker-info.txt'
    speaker_info = {}
    with open(speaker_info_path, 'r') as file:
        lines = file.readlines()[1:]  # Skip the header
        for line in lines:
            parts = line.strip().split()
            speaker_id = parts[0]
            age = parts[1]
            gender = parts[2]
            accent = parts[3]
            region = parts[4] if len(parts) > 4 else ""
            comment = " ".join(parts[5:]) if len(parts) > 5 else ""
            speaker_info[speaker_id] = {
                "age": age,
                "gender": gender,
                "accent": accent,
                "region": region,
                "comment": comment,
            }
    return speaker_info

def create_dataset(target_sample_rate):
    # Define paths
    wav_dir = f'./data/VCTK/raw/wav{int(target_sample_rate // 1e3)}'
    txt_dir = './data/VCTK/raw/txt'
    speaker_info = read_speaker_info()
    
    dataset = []

    files_to_process = []
    for root, dirs, files in os.walk(wav_dir):
        for file in files:
            if file.endswith(".wav"):
                files_to_process.append((root, file))

    for root, file in tqdm(files_to_process, desc="Creating dataset", unit="file"):
        file_path = os.path.join(root, file)
        file_name = os.path.basename(file)
        speaker_id, text_id = file_name.split("_")[0], file_name.split("_")[1].split(".")[0]
        text_file_path = os.path.join(txt_dir, speaker_id, "{}_{}.txt".format(speaker_id, text_id))
        
        # Check if the text file exists
        if not os.path.exists(text_file_path):
            print(f"Text file not found for {file_name}, skipping...")
            continue
        
        with open(text_file_path, 'r') as text_file:
            text = text_file.read().strip()
        
        speaker_meta = speaker_info.get(speaker_id, {})
        entry = {
            "speaker_id": speaker_id,
            "text": text,
            "path": file_path
        }
        dataset.append(entry)
        
    df = pd.DataFrame(dataset)
    train_df, val_df, test_df = split_dataset(df)

    train_df.to_csv('train_{}.csv'.format(int(target_sample_rate//1e3)))
    val_df.to_csv('val_{}.csv'.format(int(target_sample_rate//1e3)))
    test_df.to_csv('test_{}.csv'.format(int(target_sample_rate//1e3)))
    
    
    return train_df, val_df, test_df
    
def split_dataset(df, train_size=0.7, val_size=0.15, test_size=0.15, random_state=42):
    # Ensure the split proportions sum to 1
    assert train_size + val_size + test_size == 1.0, "Train, validation, and test sizes must sum to 1.0"
    
    # Get unique speakers
    speakers = df['speaker_id'].unique()
    
    # Split speakers into train and temp (val + test)
    train_speakers, temp_speakers = train_test_split(speakers, train_size=train_size, random_state=random_state)
    
    # Calculate the proportion for validation in the temp split
    val_proportion = val_size / (val_size + test_size)
    
    # Split temp_speakers into validation and test sets
    val_speakers, test_speakers = train_test_split(temp_speakers, train_size=val_proportion, random_state=random_state)
    
    # Assign entries to the respective sets
    train_df = df[df['speaker_id'].isin(train_speakers)]
    val_df = df[df['speaker_id'].isin(val_speakers)]
    test_df = df[df['speaker_id'].isin(test_speakers)]
    
    return train_df, val_df, test_df

In [83]:
def load_split(target_sample_rate):
#     train_data = load_dataset_hdf5("./data/VCTK/processed/train_{}.h5".format(int(target_sample_rate // 1e3)))
#     val_data = load_dataset_hdf5("./data/VCTK/processed/val_{}.h5".format(int(target_sample_rate // 1e3)))
#     test_data = load_dataset_hdf5("./data/VCTK/processed/test_{}.h5".format(int(target_sample_rate // 1e3)))  
    
#     train_df = pd.DataFrame(train_data)
#     val_df = pd.DataFrame(val_data)    
#     test_df = pd.DataFrame(test_data) 
    
    train_df = pd.read_csv("./data/VCTK/processed/train_{}.csv".format(int(target_sample_rate // 1e3)))
    val_df = pd.read_csv("./data/VCTK/processed/val_{}.csv".format(int(target_sample_rate // 1e3)))   
    test_df = pd.read_csv("./data/VCTK/processed/test_{}.csv".format(int(target_sample_rate // 1e3)))
    
    return train_df, val_df, test_df


In [84]:
### ACTUALLY LOAD/CREATE DATAFRAMES

return_dfs = True

if (not os.path.exists("./data/VCTK/processed/train_{}.csv".format(SRkHz)) or
    not os.path.exists("./data/VCTK/processed/val_{}.csv".format(SRkHz)) or
    not os.path.exists("./data/VCTK/processed/test_{}.csv".format(SRkHz))):
    
    # Create dataset
    train_df, val_df, test_df = create_dataset(SR)
elif return_dfs:
    # Load from data
    train_df, val_df, test_df = load_split(SR)
    

Creating dataset:  96%|█████████▌| 42751/44455 [00:08<00:00, 5268.22file/s]

Text file not found for p315_009.wav, skipping...
Text file not found for p315_380.wav, skipping...
Text file not found for p315_343.wav, skipping...
Text file not found for p315_357.wav, skipping...
Text file not found for p315_196.wav, skipping...
Text file not found for p315_141.wav, skipping...
Text file not found for p315_154.wav, skipping...
Text file not found for p315_183.wav, skipping...
Text file not found for p315_418.wav, skipping...
Text file not found for p315_020.wav, skipping...
Text file not found for p315_022.wav, skipping...
Text file not found for p315_208.wav, skipping...
Text file not found for p315_397.wav, skipping...
Text file not found for p315_368.wav, skipping...
Text file not found for p315_142.wav, skipping...
Text file not found for p315_221.wav, skipping...
Text file not found for p315_209.wav, skipping...
Text file not found for p315_023.wav, skipping...
Text file not found for p315_027.wav, skipping...
Text file not found for p315_033.wav, skipping...


Creating dataset: 100%|██████████| 44455/44455 [00:09<00:00, 4928.57file/s]


In [85]:
test_df.to_csv("./data/VCTK/processed/test_8.csv")

In [86]:
unique_speakers = (train_df['speaker_id'].unique().tolist() + 
                    val_df['speaker_id'].unique().tolist() + 
                    test_df['speaker_id'].unique().tolist())

speaker_to_idx = {speaker: idx for idx, speaker in enumerate(unique_speakers)}

In [87]:
text = (train_df['text'].tolist() + 
        val_df['text'].tolist() + 
        test_df['text'].tolist())

In [88]:
### CREATE TOKENIZER

args = {
    "pad_id": PAD_ID,
    "bos_id": BOS_ID,
    "eos_id": EOS_ID,
    "unk_id": UNK_ID,
    "input": "./data/VCTK/raw/text.txt",
    "vocab_size": VOCAB_SIZE,
    "model_prefix": "Multi30k",
    # "model_type": "word",
}
combined_args = " ".join(
    "--{}={}".format(key, value) for key, value in args.items())
sentencepiece.SentencePieceTrainer.Train(combined_args)

vocab = sentencepiece.SentencePieceProcessor()
vocab.Load("Multi30k.model")

sentencepiece_trainer.cc(178) LOG(INFO) Running command: --pad_id=0 --bos_id=1 --eos_id=2 --unk_id=3 --input=./data/VCTK/raw/text.txt --vocab_size=4000 --model_prefix=Multi30k
sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: ./data/VCTK/raw/text.txt
  input_format: 
  model_prefix: Multi30k
  model_type: UNIGRAM
  vocab_size: 4000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vo

True

In [89]:
# print("Vocabulary size:", vocab.GetPieceSize())
# print()

# for example in text[:3]:
#   sentence = example
#   pieces = vocab.EncodeAsPieces(sentence)
#   indices = vocab.EncodeAsIds(sentence)
#   print(sentence)
#   print(pieces)
#   print(vocab.DecodePieces(pieces))
#   print(indices)
#   print(vocab.DecodeIds(indices))
#   print()

# piece = vocab.EncodeAsPieces("the")[0]
# index = vocab.PieceToId(piece)
# print(piece)
# print(index)
# print(vocab.IdToPiece(index))

In [90]:
class VCTK(Dataset):
    def __init__(self, df):
        self.df = df
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        """
        What i want the code to do is as follows:
        
        - select target information
        - clip target audio to MAX_WAVE_LEN with random start position
        - clip refrence audio to MAX_WAVE_LEN with another random start position
        """
        
        # get target information
        row = self.df.iloc[idx]
        
        og_audio, sr = librosa.load(row['path'], sr=SR)
        og_audio = torch.from_numpy(og_audio)
        
        text = row['text']
        tokens = torch.tensor(vocab.EncodeAsIds(text))
        
        speaker_id = speaker_to_idx[row['speaker_id']]
        
        # get refrence information b-vae way --- ADHERE TO THESE COMMENTS
        ## get max wave len random section of audio 
        ## get another max wav len random section of audio -- this is ref audio
        if len(og_audio) > MAX_WAV_LEN:
            start_idx = np.random.randint(0, len(og_audio) - MAX_WAV_LEN)
            audio = og_audio[start_idx:start_idx + MAX_WAV_LEN]
            ref_start_idx = np.random.randint(0, len(og_audio) - MAX_WAV_LEN)
            ref_audio = og_audio[ref_start_idx:ref_start_idx + MAX_WAV_LEN]
        else:
            audio = og_audio
            ref_audio = og_audio
        
        
        
        
        # get refrence information styletts way --- IGNORE THIS SECTION
        # ref_row = self.df[self.df['speaker_id'] == row['speaker_id']].sample(1).iloc[0]
        
        # ref_audio, sr = librosa.load(ref_row['path'], sr=SR)
        # ref_audio = torch.from_numpy(ref_audio)
        
        # ref_text = ref_row['text']
        # ref_tokens = torch.tensor(vocab.EncodeAsIds(ref_text))
        
        # ref_speaker_id = speaker_to_idx[ref_row['speaker_id']]
        
        
        
        
        sample = {
            'audio': audio,
            'tokens': tokens,  # Token IDs
            'speaker_id': torch.tensor(speaker_id, dtype=torch.long),  # Numeric speaker ID
            'ref_audio': ref_audio,
            # 'ref_tokens': ref_tokens,  # Token IDs
            # 'ref_speaker_id': torch.tensor(ref_speaker_id, dtype=torch.long),
        }
        
        return sample
    
train_dataset = VCTK(train_df)

val_dataset = VCTK(val_df)

test_dataset = VCTK(test_df)

In [91]:
def Collate(batch):
    # Separate the batch into individual lists
    audios = [item['audio'] for item in batch]
    ref_audios = [item['ref_audio'] for item in batch]
    tokens = [item['tokens'] for item in batch]
    speaker_ids = [item['speaker_id'] for item in batch]

    # Pad the audio and token sequences
    padded_audios = pad_sequence(audios, batch_first=True, padding_value=AUDIO_PAD_ID)
    padded_ref_audios = pad_sequence(ref_audios, batch_first=True, padding_value=AUDIO_PAD_ID)
    padded_tokens = pad_sequence(tokens, batch_first=True, padding_value=AUDIO_PAD_ID)

    # Stack speaker IDs
    speaker_ids = torch.stack(speaker_ids)

    # Create the batch dictionary
    batch = {
        'audio': padded_audios,
        'ref_audio': padded_ref_audios,
        'tokens': padded_tokens,
        'speaker_id': speaker_ids
    }

    return batch

In [92]:
# Create the DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=Collate, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=Collate, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=Collate, pin_memory=True)

## Training Loop

In [None]:
### TRAINING HYPER PARAMS

epochs = 1

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cpu


In [None]:
class SpeakerEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(SpeakerEncoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.conv_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.AvgPool1d(kernel_size=2)
            ) for _ in range(4)
        ])
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        for conv_block in self.conv_blocks:
            x = conv_block(x)
        x = torch.mean(x, dim=-1)  # Global average pooling
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

class ContentEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(ContentEncoder, self).__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.attention_blocks = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads=4) for _ in range(2)
        ])
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.dropout(x, 0.2)
        x = F.relu(self.conv2(x))
        x = F.dropout(x, 0.2)
        for attention_block in self.attention_blocks:
            x, _ = attention_block(x, x, x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.attention_blocks = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads=4) for _ in range(2)
        ])
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.postnet = nn.Sequential(
            nn.Conv1d(output_dim, output_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(output_dim, output_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(output_dim, output_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(output_dim, output_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(output_dim, output_dim, kernel_size=5, padding=2),
        )
    
    def forward(self, zc, zs):
        x = torch.cat((zc, zs), dim=-1)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        for attention_block in self.attention_blocks:
            x, _ = attention_block(x, x, x)
        x = self.fc(x)
        x = self.postnet(x) + x  # Residual connection
        return x

class BetaVAEVC(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim):
        super(BetaVAEVC, self).__init__()
        self.speaker_encoder = SpeakerEncoder(input_dim, hidden_dim, latent_dim)
        self.content_encoder = ContentEncoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim * 2, hidden_dim, output_dim)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        mu_c, logvar_c = self.content_encoder(x)
        mu_s, logvar_s = self.speaker_encoder(x)
        zc = self.reparameterize(mu_c, logvar_c)
        zs = self.reparameterize(mu_s, logvar_s)
        recon_x = self.decoder(zc, zs)
        return recon_x, mu_c, logvar_c, mu_s, logvar_s

# Example usage
input_dim = 1  # Mel-spectrogram dimension
hidden_dim = 256
latent_dim = 128
output_dim = 1  # Reconstructed Mel-spectrogram dimension

model = BetaVAEVC(input_dim, hidden_dim, latent_dim, output_dim)

In [104]:
class MultiScaleMelLoss(nn.Module):
    def __init__(self):
        super(MultiScaleMelLoss, self).__init__()
        self.window_lengths = [32, 64, 128, 256, 512, 1024, 2048]
        self.hop_lengths = [wl // 4 for wl in self.window_lengths]
        self.mel_bin_sizes = [5, 10, 20, 40, 80, 160, 320]
        self.loss_fn = nn.L1Loss(reduction='none')  # Use reduction='none' to apply masking later

    def forward(self, recon_waveform, target_waveform, mask):
        total_loss = 0
        for wl, hl, mel_bins in zip(self.window_lengths, self.hop_lengths, self.mel_bin_sizes):
            mel_transform = T.MelSpectrogram(
                sample_rate=SR,
                n_fft=wl,
                hop_length=hl,
                n_mels=mel_bins
            )
            recon_mel = mel_transform(recon_waveform)
            target_mel = mel_transform(target_waveform)
            
            # Extend mask to match mel-spectrogram dimensions
            mel_mask = mask.unsqueeze(1).expand_as(recon_mel)
            
            # Apply masking to the loss
            loss = self.loss_fn(recon_mel, target_mel)
            loss = loss * mel_mask
            total_loss += loss.sum() / mel_mask.sum()  # Normalize by the number of unmasked elements

        return total_loss

# Example usage
loss_fn = MultiScaleMelLoss()
recon_waveform = torch.randn(2, 32000)  # Example reconstructed waveform
target_waveform = torch.randn(2, 32000)  # Example target waveform
mask = torch.ones_like(recon_waveform, dtype=torch.bool)  # Example mask

loss = loss_fn(recon_waveform, target_waveform, mask,)
print(f'Multi-Scale Mel Loss: {loss.item()}')

RuntimeError: The expanded size of the tensor (4001) must match the existing size (32000) at non-singleton dimension 2.  Target sizes: [2, 5, 4001].  Tensor sizes: [2, 1, 32000]

In [None]:
for epoch in range(epochs):
    for batch in tqdm(train_loader):
        features = batch.values()
        features = [f.to(device) for f in features]
        audio, ref_audio, tokens, speaker_id = features
        
        token_mask = tokens == PAD_ID
        audio_mask = audio_mask == AUDIO_PAD_ID
        ref_audio_mask = ref_audio_mask == AUDIO_PAD_ID
        
        content_encoding = model.content_encoder(audio, audio_mask)
        speaker_encoding = model.speaker_encoder(ref_audio, ref_audio_mask)
        pitch_feature = model.pitch_extractor(audio, ref_audio, audio_mask, ref_audio_mask)
        
        new_wave, content_encoding = model.decode(content_encoding, speaker_encoding, pitch_feature)
        
        rec_loss = multi_mel_loss(audio, new_wave, audio_mask)
        # add KL loss from B-VAE
        # maybe add contrastive loss from content vec -- requires transforms
        # maybe add vector quantization loss from GRVQ or VQMIVC
        # maybe add MI loss from VQMIVC
        
        
        
        
        print(ref_audio.size())
        break

  0%|          | 0/960 [00:00<?, ?it/s]

torch.Size([32, 16000])





In [None]:
# Start on Training process
## Get train loop working
## Make loss functions
## Design models

In [None]:
# Work on Inference process
## make Inference dataloaders ? 
## make Inference pipeline