In [None]:
!pip install einops

In [None]:
import os
import pandas as pd 
import numpy as np
import shutil
from glob import glob
from tqdm.notebook import tqdm
from scipy.io import wavfile
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

import torchaudio

from einops import repeat
from einops.layers.torch import Rearrange

# downloading pre-trained models
import urllib.request

# trimming silences in .wav files
from pydub import AudioSegment

# The W&B API key is a credential and must not be stored in the notebook.
# It is read from the environment (e.g. exported from a notebook secret before
# this cell runs). The key that was previously hard-coded here has been exposed
# and should be revoked.
import warnings

wandb_api_key = os.environ.get("WANDB_API_KEY")
if not wandb_api_key:
    warnings.warn("WANDB_API_KEY is not set; Weights & Biases may prompt for a login or fail to authenticate.")

In [None]:
# Run on the GPU whenever CUDA is usable in this session, otherwise stay on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [None]:
def _load_or_create_token(path, seed, dim=1024):
    """Return the special token stored at ``path`` as a ``(1, dim)`` tensor.

    The token used to be generated by running a commented-out ``np.save`` line
    by hand, so a fresh kernel failed when the file was missing. The file is now
    created on demand; an existing file is always loaded unchanged.

    Args:
        path: Location of the ``.npy`` file holding the token vector.
        seed: Seed for the generator that draws a missing token, so that the
            same vector is re-created in every fresh session.
        dim: Length of the token vector drawn when the file is missing.
    """
    if not os.path.exists(path):
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        generator = torch.Generator().manual_seed(seed)
        np.save(path, torch.randn(dim, generator=generator).numpy())
    return torch.from_numpy(np.load(path)).unsqueeze(0)


# Distinct seeds keep the start and stop tokens different from each other.
START_TOKEN = _load_or_create_token('/kaggle/working/start-token.npy', seed=0)
STOP_TOKEN = _load_or_create_token('/kaggle/working/stop-token.npy', seed=1)

In [None]:
# Directories that CLVCDataset globs for `*.content.npy` / `*.se.npy` feature files.
_KAGGLE_INPUT = '/kaggle/input'
source_content_path = f'{_KAGGLE_INPUT}/europarl-source-content-features'
target_content_path = f'{_KAGGLE_INPUT}/europarl-target-content-features'
target_speaker_path = f'{_KAGGLE_INPUT}/europarl-target-speaker-features'

In [None]:
# Utterance ids in this array are excluded when CLVCDataset builds its file lists.
_two_channel_list_path = '/kaggle/input/europarl-extra-files/two_channel_wavs.npy'
two_channel_wavs = np.load(_two_channel_list_path)

In [None]:
MAX_SEQUENCE_LENGTH = 350 + 1  # 350 content positions + 1 token for SOS


class CLVCDataset(Dataset):
    """Pre-computed source-content, target-content and target-speaker embeddings.

    Each of the three directories is globbed for ``.npy`` files, utterances
    listed in ``two_channel_wavs`` and in ``files_to_remove.npy`` are dropped,
    and the remaining paths are sorted. Items are paired by position in the
    sorted lists.
    NOTE(review): only the list lengths are asserted to match; that the three
    lists contain the same utterance ids is assumed - confirm.
    """

    def __init__(self, source_content_path, target_content_path, target_speaker_path, max_sequence_length=270):
        """
        Args:
            source_content_path: Directory with ``*.content.npy`` source features.
            target_content_path: Directory with ``*.content.npy`` target features.
            target_speaker_path: Directory with ``*.se.npy`` speaker embeddings.
            max_sequence_length: Length every content sequence is truncated and
                zero-padded to (the target length includes its start/stop token).
        """
        self.max_sequence_length = max_sequence_length

        self.source_content_embeddings = glob(f'{source_content_path}/*.content.npy')
        self.target_content_embeddings = glob(f'{target_content_path}/*.content.npy')
        self.target_speaker_embeddings = glob(f'{target_speaker_path}/*.se.npy')

        # `two_channel_wavs` is the module-level array of utterance ids to skip.
        self._exclude(two_channel_wavs)
        self._trim_dataset()

        assert len(self.target_content_embeddings) != 0, 'Length of content embeddings may not be zero.'
        assert len(self.target_content_embeddings) == len(self.target_speaker_embeddings), 'Target speaker content embeddings must be same length as target speaker embeddings'
        assert len(self.target_content_embeddings) == len(self.source_content_embeddings), 'Source content embeddings must be same length as target content embeddings'

    def __len__(self):
        """Number of utterances left after filtering."""
        return len(self.source_content_embeddings)

    @staticmethod
    def _keep(paths, suffix, excluded_ids):
        """Return ``paths`` sorted, minus those whose file name without ``suffix`` is excluded."""
        return sorted(path for path in paths
                      if path.split('/')[-1].replace(suffix, '') not in excluded_ids)

    def _exclude(self, utterance_ids):
        """Drop the given utterance ids from all three file lists."""
        # A set makes each membership test O(1) instead of scanning the array per file.
        excluded_ids = set(np.asarray(utterance_ids).ravel().tolist())
        self.source_content_embeddings = self._keep(self.source_content_embeddings, '.content.npy', excluded_ids)
        self.target_content_embeddings = self._keep(self.target_content_embeddings, '.content.npy', excluded_ids)
        self.target_speaker_embeddings = self._keep(self.target_speaker_embeddings, '.se.npy', excluded_ids)

    def _trim_dataset(self):
        """Drop the utterances listed in ``files_to_remove.npy``."""
        files_to_remove = np.load('/kaggle/input/europarl-extra-files/files_to_remove.npy')
        self._exclude(files_to_remove)

    @staticmethod
    def pad_sequence(sequence, max_sequence_length):
        """Zero-pad dim 1 of a 3-D tensor up to ``max_sequence_length``.

        Sequences that are already at least that long are returned unchanged
        (no truncation happens here).
        """
        seq_len = sequence.shape[1]
        pad_len = max(0, max_sequence_length - seq_len)

        if pad_len > 0:
            # F.pad takes (left, right) pairs starting from the last dimension.
            sequence = F.pad(sequence, (0, 0, 0, pad_len, 0, 0))

        return sequence

    def __getitem__(self, idx):
        """Load, truncate and pad the embeddings of utterance ``idx``.

        Returns:
            Tuple ``(source_content, decoder_input, decoder_target,
            speaker_embedding, source_length, target_length)`` where the three
            content tensors have ``max_sequence_length`` rows, ``decoder_input``
            is the target prefixed with ``START_TOKEN``, ``decoder_target`` is
            the target followed by ``STOP_TOKEN`` and the lengths are the
            unpadded lengths (the target length counts the extra token).
        """
        # Truncate with the instance setting so truncation and padding always
        # agree, whatever `max_sequence_length` the dataset was built with.
        max_len = self.max_sequence_length

        # load pre-computed embeddings
        source_content_embed = np.load(self.source_content_embeddings[idx])[:max_len, :]
        target_content_embed = np.load(self.target_content_embeddings[idx])[:max_len - 1, :]  # leave room for the start/stop token
        target_speaker_embed = np.load(self.target_speaker_embeddings[idx])

        # numpy array -> torch tensor (content gets a leading batch dim for padding)
        target_content_embed = torch.from_numpy(target_content_embed).unsqueeze(0)
        target_speaker_embed = torch.from_numpy(target_speaker_embed)
        source_content_embed = torch.from_numpy(source_content_embed).unsqueeze(0)

        optim_target_content_embed = torch.concat((target_content_embed, STOP_TOKEN.unsqueeze(0)), dim=1)
        target_content_embed = torch.concat((START_TOKEN.unsqueeze(0), target_content_embed), dim=1)

        # unpadded lengths
        source_length = torch.tensor(source_content_embed.shape[1])
        target_length = torch.tensor(target_content_embed.shape[1])

        # pad content sequences and drop the temporary batch dim again
        source_content_embed = self.pad_sequence(source_content_embed, max_len).squeeze(0)
        target_content_embed = self.pad_sequence(target_content_embed, max_len).squeeze(0)
        optim_target_content_embed = self.pad_sequence(optim_target_content_embed, max_len).squeeze(0)

        return source_content_embed, target_content_embed, optim_target_content_embed, target_speaker_embed, source_length, target_length

In [None]:
dataset = CLVCDataset(source_content_path=source_content_path,
                      target_content_path=target_content_path,
                      target_speaker_path=target_speaker_path,
                      max_sequence_length=MAX_SEQUENCE_LENGTH)

# 95 % of the utterances are used for training, the rest for validation.
train_split = int(len(dataset) * 0.95)

# A seeded generator keeps the train/validation split identical across kernel restarts.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_split, len(dataset) - train_split],
    generator=torch.Generator().manual_seed(42),
)
train_loader = DataLoader(train_dataset, batch_size=48, shuffle=True, pin_memory=True)
# Validation is not shuffled so every validation pass sees the same batches in the same order.
val_loader = DataLoader(val_dataset, batch_size=48, shuffle=False)

### Modules

In [None]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal positional encoding to a batch-first sequence."""

    def __init__(self, d_model, max_len):
        """
        Inputs
            d_model - Hidden dimensionality of the input (odd values are supported).
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one cosine column fewer than sine columns,
        # so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        ``x`` is indexed as ``[batch, sequence, d_model]``; the encoding is
        broadcast over the batch dimension.
        """
        x = x + self.pe[:, :x.size(1)]
        return x

In [None]:
class TransformerDecoderBlock(nn.Module):
    """Pre-norm attention block followed by a convolutional feed-forward network.

    Despite its attribute name, ``self_attention`` is used as attention from
    ``query`` onto ``memory`` (``memory`` supplies both keys and values).
    """

    def __init__(self, embed_dim, num_heads, dropout, expansion_factor=2):
        """
        Args:
            embed_dim: Feature size of ``query`` and ``memory``.
            num_heads: Number of attention heads.
            dropout: Dropout rate used in the attention and in the feed-forward network.
            expansion_factor: Channel multiplier of the hidden feed-forward convolution.
        """
        super().__init__()

        self.target_norm = nn.LayerNorm(embed_dim)
        self.query_norm = nn.LayerNorm(embed_dim)
        self.memory_norm = nn.LayerNorm(embed_dim)

        self.self_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout, batch_first=True)
        self.sa_dropout = nn.Dropout(dropout)

        # Feed-forward network: the convolutions run over the sequence axis, hence
        # the rearrangement to channels-first and back.
        self.mlp = nn.Sequential(
            nn.LayerNorm(embed_dim),
            Rearrange('B S E -> B E S'),
            nn.Conv1d(embed_dim, int(embed_dim * expansion_factor), 9, padding = (9 - 1) // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv1d(int(embed_dim * expansion_factor), embed_dim, 1, padding = (1 - 1) // 2),
            nn.Dropout(dropout),
            Rearrange('B E S-> B S E')
        )

    def forward(self,
                query,
                memory,
                query_padding_mask=None,
                query_fill_mask=None,
                memory_padding_mask=None,
                memory_fill_mask=None
               ):
        """Attend from ``query`` to ``memory`` and apply the feed-forward network.

        Args:
            query: Batch-first query sequence.
            memory: Batch-first sequence used as attention keys and values.
            query_padding_mask: Accepted for interface symmetry; currently unused.
            query_fill_mask: Optional boolean mask broadcastable to ``query``;
                positions where it is True are set to zero in the input, in the
                attention output and in the block output.
            memory_padding_mask: Optional key padding mask handed to the attention.
            memory_fill_mask: Accepted for interface symmetry; currently unused.

        Returns:
            Tensor with the same shape as ``query``.
        """
        if query_fill_mask is not None:
            # Out-of-place on purpose: the caller's tensor must not be overwritten,
            # and an in-place fill fails on leaf tensors that require grad.
            query = query.masked_fill(query_fill_mask, 0.0)

        query_in = self.query_norm(query)
        memory_in = self.memory_norm(memory)

        sa_attn_out, _ = self.self_attention(query_in,
                                             memory_in,
                                             memory_in,
                                             key_padding_mask=memory_padding_mask)

        if query_fill_mask is not None:
            # In-place is fine here: the tensor was created inside this method.
            sa_attn_out.masked_fill_(query_fill_mask, 0.0)

        # residual connections around attention and feed-forward
        x = self.sa_dropout(sa_attn_out) + query
        x = self.mlp(self.target_norm(x)) + x

        if query_fill_mask is not None:
            x.masked_fill_(query_fill_mask, 0.0)

        return x

In [None]:
class WavLMDecoder(nn.Module):
    """Maps source content features plus a speaker embedding to target content features.

    The source sequence is projected to ``hidden_dim``, combined with positional
    and speaker information, refined by a stack of ``TransformerDecoderBlock``
    layers, projected back to ``content_embedding_dim`` and finally fed to a
    small convolutional head that emits one stop-token logit per position.
    """

    def __init__(self,
                 speaker_embedding_dim,
                 content_embedding_dim,
                 hidden_dim,
                 num_heads,
                 num_layers,
                 max_sequence_length,
                 expansion_factor=2,
                 dropout=0.0,
                 duration_dropout=0.0
                ):
        """
        Args:
            speaker_embedding_dim: Size of the reference speaker embedding.
            content_embedding_dim: Size of the input and output content features.
            hidden_dim: Width of the decoder layers.
            num_heads: Attention heads per decoder layer.
            num_layers: Number of decoder layers.
            max_sequence_length: Padded sequence length the model operates on.
            expansion_factor: Channel multiplier inside each layer's feed-forward network.
            dropout: Dropout rate of the decoder layers.
            duration_dropout: Dropout rate of the stop-token head.
        """
        super().__init__()

        self.content_embedding_dim = content_embedding_dim
        self.max_sequence_length = max_sequence_length
        self.hidden_dim = hidden_dim

        self.positional_embedding = PositionalEncoding(hidden_dim, self.max_sequence_length)
        self.speaker_embedding = nn.Linear(speaker_embedding_dim, hidden_dim)
        self.content_embedding = nn.Linear(content_embedding_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)

        self.stop_token_predictor = nn.Sequential(
            nn.LayerNorm(self.content_embedding_dim),
            Rearrange('B S E -> B E S'),
            nn.Conv1d(self.content_embedding_dim, self.content_embedding_dim, 9, padding = (9 - 1) // 2),
            nn.GELU(),
            nn.Dropout(duration_dropout),
            nn.Conv1d(self.content_embedding_dim, 1, 1, padding = (1 - 1) // 2),
            Rearrange('B 1 S-> B S')
        )

        self.layers = nn.Sequential(*[TransformerDecoderBlock(hidden_dim,
                                                              num_heads,
                                                              dropout,
                                                              expansion_factor)  for _ in range(num_layers)])

        self.hidden_2_wavlm = nn.Linear(self.hidden_dim, self.content_embedding_dim)


    def forward(self, source_content, source_lengths, reference_embedding):
        """Predict target content features and stop-token logits.

        Args:
            source_content: Source features padded to ``max_sequence_length``,
                indexed ``[batch, sequence, content_embedding_dim]``.
            source_lengths: Unpadded length of every source sequence, one per batch item.
            reference_embedding: Speaker embedding indexed ``[batch, 1, speaker_embedding_dim]``.

        Returns:
            ``(features, stop_token_logits)`` indexed ``[batch, sequence,
            content_embedding_dim]`` and ``[batch, sequence]``.
        """
        src_padding_masks, src_fill_masks = self.get_masks(source_lengths, self.max_sequence_length, self.hidden_dim)

        ref_embedding = self.speaker_embedding(reference_embedding)
        ref_embedding = repeat(ref_embedding, 'B 1 E -> B S E', S=self.max_sequence_length)

        # Query and memory start from the same embedding, so it is computed once.
        # The clone keeps `memory` independent of any in-place masking of `x`.
        x = self.positional_embedding(self.content_embedding(source_content)) + ref_embedding
        memory = x.clone()

        for idx, layer in enumerate(self.layers):
            # Only the first layer zeroes the padded query positions.
            x = layer(query=x,
                      memory=memory,
                      query_padding_mask= src_padding_masks if idx < 1 else None,
                      query_fill_mask= src_fill_masks if idx < 1 else None,
                      memory_padding_mask= src_padding_masks,
                      memory_fill_mask= src_fill_masks
                     )

        x = self.hidden_2_wavlm(x)
        stop_token_out = self.stop_token_predictor(x)
        return x, stop_token_out

    @staticmethod
    def get_masks(seq_lens, max_seq_len, embed_dim):
        """Build padding masks from sequence lengths.

        Args:
            seq_lens: Unpadded lengths, one per batch item.
            max_seq_len: Padded sequence length.
            embed_dim: Feature size the fill mask is expanded to.

        Returns:
            ``(masks, fill_masks)``: boolean tensors of shape ``(B, max_seq_len)``
            and ``(B, max_seq_len, embed_dim)`` that are True at padded
            positions (index >= length). Both live on the module-level ``device``.
        """
        lengths = torch.as_tensor(seq_lens).reshape(-1, 1)
        positions = torch.arange(max_seq_len, device=lengths.device).unsqueeze(0)
        # One broadcast comparison replaces B * max_seq_len scalar comparisons.
        masks = (positions >= lengths).to(device)
        fill_masks = masks.unsqueeze(-1).expand(-1, -1, embed_dim).contiguous()

        return masks, fill_masks

    @staticmethod
    def get_causal_masks(max_seq_len):
        """Return a ``(max_seq_len, max_seq_len)`` boolean mask that is True above the diagonal."""
        return ~torch.tril(torch.ones((max_seq_len,max_seq_len), dtype=torch.bool, device=device))
    
# dec = WavLMDecoder(speaker_embedding_dim=512, content_embedding_dim=1024, hidden_dim=256, num_heads=2, num_layers=3, max_sequence_length=351)
# dec(source_content = torch.randn(3, 351, 1024), source_lengths = torch.tensor([100, 150, 200]), reference_embedding = torch.randn(3, 1, 512))

In [None]:
class WavLMVCLoss(nn.Module):
    """L1 loss on the unpadded content features plus cross-entropy on the stop position."""

    def __init__(self):
        super().__init__()
        self.feature_loss = nn.L1Loss()
        self.stop_token_loss = nn.CrossEntropyLoss()

    def forward(self,
                feature_predictions,
                feature_targets,
                stop_token_predictions,
                stop_token_targets,
                feature_masks):
        """Compute the training losses.

        Args:
            feature_predictions: Predicted features indexed ``[batch, sequence, embed]``.
            feature_targets: Target features with the same shape as the predictions.
            stop_token_predictions: Logits indexed ``[batch, sequence]``; the
                sequence positions act as the classes of the cross-entropy.
            stop_token_targets: Class index (sequence position) per batch item.
            feature_masks: Boolean ``[batch, sequence]`` mask that is True at
                padded positions; those positions are excluded from the L1 loss.

        Returns:
            Tuple ``(total_loss, feature_loss, stop_token_loss)``.
        """
        # Targets never receive gradients. detach() achieves that without
        # mutating the caller's tensors and also works for non-leaf tensors,
        # where assigning `requires_grad = False` raises.
        feature_targets = feature_targets.detach()
        stop_token_targets = stop_token_targets.detach()

        # True where a position holds real (unpadded) data, expanded over the feature axis.
        valid_positions = (~feature_masks).unsqueeze(-1).expand_as(feature_predictions)
        masked_feature_predictions = feature_predictions.masked_select(valid_positions).float()
        masked_feature_targets = feature_targets.masked_select(valid_positions).float()

        feature_loss = self.feature_loss(masked_feature_predictions, masked_feature_targets)
        stop_token_loss = self.stop_token_loss(stop_token_predictions, stop_token_targets)
        total_loss = feature_loss + stop_token_loss

        return (total_loss, feature_loss, stop_token_loss)

In [None]:
class WavLMVC(pl.LightningModule):
    """Lightning wrapper that trains a ``WavLMDecoder`` with ``WavLMVCLoss``."""

    def __init__(self, **kwargs):
        """All keyword arguments are stored as hyper-parameters and forwarded to ``WavLMDecoder``."""
        super().__init__()
        self.save_hyperparameters()
        self.model = WavLMDecoder(**kwargs)
        self.loss_module = WavLMVCLoss()

    def configure_optimizers(self):
        """Adam with a step-wise learning-rate decay (x0.1 at milestones 50 and 100)."""
        optimizer = torch.optim.Adam(self.parameters(), betas=(0.9, 0.98))
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50,100], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _shared_step(self, batch, stage):
        """Run the model on one batch, log the three losses as ``{stage}_*`` and return the total loss."""
        (source_content_embed,
         target_content_embed,  # decoder input with start token; not used by this model
         optim_target_content_embed,
         target_speaker_embed,
         source_length,
         target_length) = batch

        (target_predictions, stop_token_predictions) = self.model(source_content=source_content_embed,
                                                                  source_lengths=source_length,
                                                                  reference_embedding=target_speaker_embed)

        trg_padding_masks, _ = WavLMDecoder.get_masks(target_length,
                                                      self.hparams.max_sequence_length,
                                                      self.hparams.hidden_dim)

        # `target_length` counts the target frames plus the stop token, so the stop
        # token sits at index `target_length - 1`. Using `target_length` itself is
        # one position too far and out of range for full-length sequences, because
        # the stop head only has `max_sequence_length` classes.
        stop_token_index = target_length - 1

        (total_loss, feature_loss, stop_token_loss) = self.loss_module(feature_predictions=target_predictions,
                                                                       feature_targets=optim_target_content_embed,
                                                                       stop_token_predictions=stop_token_predictions,
                                                                       stop_token_targets=stop_token_index,
                                                                       feature_masks=trg_padding_masks)

        self.log(f'{stage}_total_loss', total_loss, on_step=True, on_epoch=True)
        self.log(f'{stage}_feature_loss', feature_loss, on_step=True, on_epoch=True)
        self.log(f'{stage}_stop_loss', stop_token_loss, on_step=True, on_epoch=True)

        return total_loss

    def training_step(self, batch, batch_idx):
        """Return the total loss of the batch for optimisation."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Log the validation losses of the batch."""
        self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log the test losses of the batch."""
        self._shared_step(batch, 'test')

### Train the model

In [None]:
import wandb

# W&B project and the name under which this run is logged.
project = 'WavLMVC'
model_version = 'v:2.0.2'

# Finish any run left over from a previous execution of this cell before creating the logger.
wandb.finish()
wandb_logger = WandbLogger(name=model_version, project=project)

In [None]:
def train_model(**kwargs):
    """Train a ``WavLMVC`` model and return ``(model, trainer)``.

    Args:
        **kwargs: Hyper-parameters forwarded unchanged to ``WavLMVC``.

    Returns:
        The model reloaded from the trainer's best checkpoint when one was
        written, otherwise the model as it is at the end of training, together
        with the ``pl.Trainer`` that ran the fit.
    """
    # NOTE(review): `overfit_batches=4` and `limit_val_batches=4` restrict the run to
    # a handful of batches - this looks like a debugging setup; confirm before a full run.
    trainer = pl.Trainer(default_root_dir=os.path.join('/kaggle/working/checkpoints', f'WavLM VC - {model_version}'),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=None,
                         max_steps=5000,
                         logger=wandb_logger,
                         gradient_clip_val=1.0,
                         check_val_every_n_epoch=1,
                         overfit_batches=4,
                         num_sanity_val_steps=0,
                         limit_val_batches=4,
#                          callbacks=[
#                              ModelCheckpoint(save_weights_only=False,
#                                              save_last=True,
#                                              mode="min", monitor="val_total_loss",
#                                              save_top_k=3
#                                             ),
                                             
#                             ModelCheckpoint(save_weights_only=False,
#                                      save_last=True,
#                                      mode="max", monitor="step",
#                                      every_n_train_steps =500,
#                                      save_top_k=3
#                                     )]
                        )

    model = WavLMVC(**kwargs)
    num_params = sum(p.numel() for p in model.parameters())
    print(f'Number of params: {num_params}')

    trainer.fit(model, train_loader, val_loader)

    # Reload the best checkpoint only when one exists; with checkpointing disabled
    # or no checkpoint written yet the path is empty and loading it would fail.
    checkpoint_callback = trainer.checkpoint_callback
    best_model_path = checkpoint_callback.best_model_path if checkpoint_callback is not None else ''
    if best_model_path:
        model = WavLMVC.load_from_checkpoint(best_model_path)
    return model, trainer

In [None]:
# Hyper-parameters of the decoder; they are passed through train_model to WavLMVC.
model_hparams = dict(
    speaker_embedding_dim=512,
    content_embedding_dim=1024,
    hidden_dim=512,
    num_heads=2,
    num_layers=4,
    max_sequence_length=MAX_SEQUENCE_LENGTH,
    expansion_factor=2,
    dropout=0.0,
    duration_dropout=0.0,
)

model, results = train_model(**model_hparams)