# MKL-VC Demo

## Setup the environment

To set up the environment and download the models, follow the guidelines from the [KNN-VC project](https://github.com/bshall/knn-vc):


1. Create a conda/venv environment with torch, torchaudio, and numpy dependencies
```    
    $ conda create --name mkl_vc python=3.10
    
    $ conda activate mkl_vc
```

2. Check that pip is called from the created env. You may want to use the full path to pip (like ``/home/user/miniconda3/envs/mkl_vc/bin/pip``) to install dependencies
```
    $ which pip

    $ pip install torch==2.4.0 torchaudio==2.4.0 numpy==2.0.2 scipy==1.14.1
```

3. Clone the repo
```
    $ git clone https://github.com/alobashev/mkl-vc.git

    $ cd mkl-vc
```
4. Download [WavLM Large](https://github.com/microsoft/unilm/tree/master/wavlm) and place it into the ``mkl-vc/models`` folder as ``WavLM-Large.pt``

5. Download ["kNN-VC with prematched HiFiGAN"](https://github.com/bshall/knn-vc/releases/download/v0.1/prematch_g_02500000.pt) and ["kNN-VC with regular HiFiGAN"](https://github.com/bshall/knn-vc/releases/download/v0.1/g_02500000.pt) and place them into the ``mkl-vc/models`` folder.

6. The resulting project structure should look like this:
```
├── hifigan                                 # adapted hifigan code to vocode wavlm features
│   ├── config_v1_wavlm.json                # hifigan config for use with wavlm features
│   ├── meldataset.py                       # mel-spectrogram transform used during hifigan training
│   ├── models.py                           # hifigan model definition
│   ├── train.py                            # hifigan training script
│   └── utils.py                            # utilities used for hifigan inference/training
├── models
│   ├── g_02500000.pt                       # original HiFiGAN checkpoint
│   ├── prematch_g_02500000.pt              # prematched HiFiGAN from the KNN-VC project
│   └── WavLM-Large.pt
├── wavlm
│   ├── modules.py                          # wavlm helper functions (from original WavLM repo)
│   └── WavLM.py                            # wavlm modules (from original WavLM repo)
├── matcher.py                              # model wrapper for KNeighborsVC pipeline
├── main.py
├── mkl_vc_demo.ipynb
└── README.md                               
```

## Define methods

In [1]:
import json
import numpy as np
import torch, torchaudio
import scipy.io.wavfile as wavf

from wavlm.WavLM import WavLM, WavLMConfig
from hifigan.models import Generator as HiFiGAN
from hifigan.utils import AttrDict
from matcher import KNeighborsVC


# Dimension of WavLM embeddings, i.e. the number of feature columns
# this demo expects in the feature matrices.
EMBED_LEN = 1024


def MKL(A, B):
    """Return the linear (MKL) transport matrix that maps covariance A onto B.

    For symmetric positive semi-definite matrices ``A`` and ``B`` this
    evaluates the closed form

        T = A^{-1/2} (A^{1/2} B A^{1/2})^{1/2} A^{-1/2}

    through two symmetric eigendecompositions, so that ``T @ A @ T`` is
    (up to the EPS regularisation) equal to ``B``.

    Args:
        A: Source covariance, symmetric array of shape (d, d).
        B: Target covariance, symmetric array of shape (d, d).

    Returns:
        Real symmetric array ``T`` of shape (d, d).
    """
    # Double-precision machine epsilon; keeps the square roots strictly
    # positive so that Da can be inverted even for rank-deficient A.
    EPS = 2.2204e-16

    # A is symmetric, so eigh is the right solver: it guarantees real
    # eigenvalues and orthonormal eigenvectors (eig may return complex or
    # non-orthogonal results for repeated eigenvalues).
    eigvals_a, Ua = np.linalg.eigh(A)
    # Clamp round-off negatives and regularise the eigenvalue *vector* before
    # building the diagonal matrix, so off-diagonal entries stay exactly 0.
    sqrt_a = np.sqrt(np.clip(eigvals_a, 0, None) + EPS)
    Da = np.diag(sqrt_a)
    C = Da @ Ua.T @ B @ Ua @ Da

    eigvals_c, Uc = np.linalg.eigh(C)
    Dc = np.diag(np.sqrt(np.clip(eigvals_c, 0, None) + EPS))
    Da_inv = np.diag(1 / sqrt_a)
    T = Ua @ Da_inv @ Uc @ Dc @ Uc.T @ Da_inv @ Ua.T
    return T


def apply_mkl(X0, X1):
    """Map the samples in X0 onto the mean and covariance of X1.

    Rows are treated as observations and columns as variables
    (``rowvar=False``). X0 is centred, multiplied by the MKL transport
    matrix computed from both covariances, and shifted to the mean of X1.

    Args:
        X0: Source samples, 2-D array (observations x variables).
        X1: Target samples with the same number of columns as X0.

    Returns:
        Real-valued array with the same shape as X0.
    """
    transport = MKL(np.cov(X0, rowvar=False), np.cov(X1, rowvar=False))
    source_mean = np.mean(X0, axis=0)
    target_mean = np.mean(X1, axis=0)
    centered = X0 - source_mean
    # Discard any imaginary residue left over from the eigendecompositions.
    return np.real(centered @ transport + target_mean)


def apply_mkl_batched(X0, X1, batch_size):
    """Apply the MKL mapping independently to consecutive column blocks.

    The feature columns of X0 and X1 are split into chunks of ``batch_size``
    columns and each chunk is transformed with ``apply_mkl``. A trailing
    chunk that contains a single column is copied from X0 unchanged, since
    a 1-column chunk has no covariance matrix to decompose.

    Args:
        X0: Source features, 2-D array (observations x features).
        X1: Target features with the same number of columns as X0.
        batch_size: Number of columns per chunk; must be at least 2.

    Returns:
        Array with the same shape and dtype as X0.

    Raises:
        ValueError: If ``batch_size`` is smaller than 2.
    """
    if batch_size < 2:
        raise ValueError(
            f"batch_size must be at least 2, got {batch_size}"
        )

    # Use the actual feature width rather than a hard-coded constant so the
    # function works for any embedding size.
    n_dims = X0.shape[1]
    XR = np.zeros_like(X0)
    for start in range(0, n_dims, batch_size):
        stop = min(start + batch_size, n_dims)
        if stop - start == 1:
            # Lone trailing column: pass through untouched.
            XR[:, start] = X0[:, start]
        else:
            XR[:, start:stop] = apply_mkl(X0[:, start:stop], X1[:, start:stop])
    return XR


def build_knn_vc_model(device):
    """Load the HiFiGAN vocoder and WavLM encoder and wrap them in KNeighborsVC.

    Reads the HiFiGAN config from ``./hifigan/config_v1_wavlm.json`` and the
    checkpoints ``./models/prematch_g_02500000.pt`` and
    ``./models/WavLM-Large.pt`` (paths are relative to the working directory).

    Args:
        device: Anything accepted by ``torch.device`` (e.g. ``"cuda:0"``, ``"cpu"``).

    Returns:
        A ``KNeighborsVC`` instance built from the two models, the HiFiGAN
        config and the resolved device.
    """
    device = torch.device(device)

    with open('./hifigan/config_v1_wavlm.json') as f:
        h = AttrDict(json.load(f))

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints obtained from a trusted source.
    # map_location='cpu' lets checkpoints saved on a GPU load on any machine;
    # load_state_dict then copies the weights onto the model's own device.
    generator = HiFiGAN(h).to(device)
    generator_ckpt = torch.load(
        './models/prematch_g_02500000.pt', map_location='cpu', weights_only=False
    )
    generator.load_state_dict(generator_ckpt['generator'])
    generator.eval()
    generator.remove_weight_norm()
    print(f"[HiFiGAN] Generator loaded with {sum(p.numel() for p in generator.parameters()):,d} parameters.")

    checkpoint = torch.load(
        './models/WavLM-Large.pt', map_location='cpu', weights_only=False
    )
    cfg = WavLMConfig(checkpoint['cfg'])
    wavlm = WavLM(cfg)
    wavlm.load_state_dict(checkpoint['model'])
    wavlm = wavlm.to(device)
    wavlm.eval()
    print(f"WavLM-Large loaded with {sum(p.numel() for p in wavlm.parameters()):,d} parameters.")

    return KNeighborsVC(wavlm, generator, h, device)

## Build the pipeline

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU so the demo can still run
# (slowly) on machines without one.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
knn_vc = build_knn_vc_model(device)

## Run inference

In [3]:
def run_inference(
    src_wav_path,
    ref_wav_path,
    result_path,
    batch_size,     # batch_size is MKL factorized dimension K: from 2 to ~256
):
    """Convert the source utterance towards the reference and save it as WAV.

    Uses the module-level ``knn_vc`` pipeline and ``device``. Features of both
    recordings are extracted, the feature columns are ordered by decreasing
    standard deviation of the source, the factorized MKL mapping is applied
    chunk-wise, and the mapped features are vocoded and written to
    ``result_path`` with a sample rate of 16000.

    Args:
        src_wav_path: Path of the source recording.
        ref_wav_path: Path of the reference recording.
        result_path: Output path for the converted WAV file.
        batch_size: Requested chunk width K for the factorized MKL mapping.
            It is reduced to ``len(reference features) - 4`` when the
            reference is too short.

    Raises:
        ValueError: If the effective chunk width drops below 2, i.e. the
            reference is too short or ``batch_size`` is smaller than 2.
    """
    with torch.inference_mode():
        query_seq = knn_vc.get_features(src_wav_path).to(device)
        matching_set = knn_vc.get_features(ref_wav_path).to(device)

        # Sort feature columns by source variability so that each MKL chunk
        # groups columns of similar spread.
        idxes = torch.argsort(query_seq.std(0), descending=True)
        X0 = query_seq[:, idxes].cpu().numpy()
        X1 = matching_set[:, idxes].cpu().numpy()

        # adjust batch_size to avoid artifacts
        # if X1 sequence is too short
        batch_size = min(len(X1) - 4, batch_size)
        if batch_size < 2:
            raise ValueError(
                "Effective MKL batch_size must be at least 2, got "
                f"{batch_size}: the reference has {len(X1)} feature frames; "
                "use a longer reference or a batch_size >= 2."
            )
        XR = apply_mkl_batched(X0, X1, batch_size)

        # Write the mapped columns back in their original positions, matching
        # the dtype/device of the tensor being updated.
        query_seq[:, idxes] = torch.as_tensor(
            XR, dtype=query_seq.dtype, device=query_seq.device
        )
        factorized_mkl_wav = knn_vc.vocode(query_seq[None]).cpu().squeeze()
        wavf.write(result_path, 16000, factorized_mkl_wav.numpy())

## Check the result

In [4]:
import IPython.display as ipd

# Example 1: Common Voice Corpus 17.0
# src_path is the recording to convert, ref_path the recording whose feature
# statistics are matched, and result_path the WAV file written by run_inference.
src_path = "examples/common_voice_it_23927476.wav"
ref_path = "examples/common_voice_kk_27684867.wav"
result_path = "examples/mkl_vc_demo_result.wav"

# batch_size is the MKL factorized dimension K (see run_inference).
run_inference(src_path, ref_path, result_path, batch_size=2)

In [5]:
# Play back the source recording of the current example.
print("source")
ipd.Audio(src_path)

source


In [6]:
# Play back the reference recording of the current example.
print("reference")
ipd.Audio(ref_path)

reference


In [7]:
# Play back the converted audio; 16000 is the sample rate run_inference
# used when writing the file.
print("result")
ipd.Audio(result_path, rate=16000)

result


In [8]:
# Example 2: The Expresso Dataset
src_path = "examples/ex01_happy_00021_16k.wav"
ref_path = "examples/ex04_happy_00369_16k.wav"
# NOTE: same output file as Example 1, so this run overwrites that result.
result_path = "examples/mkl_vc_demo_result.wav"

# batch_size is the MKL factorized dimension K (see run_inference).
run_inference(src_path, ref_path, result_path, batch_size=2)

In [9]:
# Play back the source recording of Example 2.
ipd.Audio(src_path)

In [10]:
# Play back the reference recording of Example 2.
ipd.Audio(ref_path)

In [11]:
# Play back the converted audio of Example 2.
ipd.Audio(result_path, rate=16000)