# Import the relevant libraries

In [None]:
import pandas as pd
import re
from creapy import creapy
import plotly
from pathlib import Path
import sys
import logging
import os
import time
import librosa
import numpy as np
import torch
from scipy.io.wavfile import write
import utils
from models import SynthesizerTrn
from speaker_encoder.voice_encoder import SpeakerEncoder
from wavlm import WavLM, WavLMConfig
from datetime import datetime
import IPython.display as ipd 
import json
import soundfile as sf
from praatio import textgrid as tg
from praatio.utilities.constants import Interval
# Run on the GPU whenever torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



## Load the models and helper functions
This includes setting the locations of the downloaded WavLM checkpoint and of the FreeVC model checkpoint to be used.

In [None]:
# Checkpoint locations. The defaults are the original author's setup; set
# WAVLM_LARGE_PATH / FREEVC_CHECKPOINT_PATH in the environment to use other
# copies without editing the notebook (the repo-relative WavLM location is
# 'wavlm/WavLM-Large.pt').
wavlm_large_path = os.environ.get('WAVLM_LARGE_PATH', '/home/lameris/CreakVC/wavlm/WavLM-Large.pt')
freevc_chpt_path = os.environ.get('FREEVC_CHECKPOINT_PATH', 'logs/libri_train_only/VQVC.pth')

# --- Argument creator ---
def arg_creator(source, target, outpath, creak, cpps, h1h2, pitch, h1a3, pitch_var):
    """Pack conversion settings into the positional list used by ``convert``.

    Layout (indices 0-8): source, target, outpath, creak, cpps, h1h2,
    pitch, h1a3, pitch_var.
    """
    paths = [source, target, outpath]
    features = [creak, cpps, h1h2, pitch, h1a3, pitch_var]
    return paths + features

# --- Load WavLM model ---
def get_cmodel():
    """Load the WavLM content model from ``wavlm_large_path``.

    Returns:
        A ``WavLM`` instance in eval mode, moved to ``device``.
    """
    # Load onto the CPU first so the checkpoint opens on any host regardless
    # of the device it was saved from; the model is moved to `device` below.
    checkpoint = torch.load(wavlm_large_path, map_location='cpu')
    cfg = WavLMConfig(checkpoint['cfg'])
    cmodel = WavLM(cfg)
    cmodel.load_state_dict(checkpoint['model'])
    cmodel.eval()
    return cmodel.to(device)

# --- Convert to tensor ---
def to_tensor(x, audio_len, device='cpu'):
    """Return *x* as a tensor on *device*.

    Args:
        x: a scalar (Python or NumPy int/float) or a tensor.
        audio_len: length of the 1-D tensor created for a scalar *x*.
        device: target torch device.

    Returns:
        For a scalar, a float32 tensor of shape ``(audio_len,)`` filled with
        *x*. Otherwise ``x.to(device)``, which returns *x* itself when it is
        already on *device*.
    """
    # NumPy scalars such as np.int64 / np.float32 are not int/float
    # subclasses; previously they fell through to `.to()` and raised.
    if isinstance(x, (int, float, np.integer, np.floating)):
        return torch.full((audio_len,), float(x), dtype=torch.float32, device=device)
    return x.to(device)

# --- Generic voice quality adjustment ---
def apply_vq(args, audio_len, scaling_factor, time_range, target_vals, sr=16000, device=device):
    """Move the creak / cpps / h1h2 / h1a3 features towards target values.

    Inside the frames covered by *time_range*, each feature becomes
    ``param + (target - param) * scaling_factor``. A factor of 1 reaches the
    target, values in (0, 1) interpolate, and values above 1 overshoot.

    Args:
        args: list in ``arg_creator`` layout; entries 3, 4, 5 and 7 (creak,
            cpps, h1h2, h1a3) may be scalars or tensors.
        audio_len: number of feature frames (length of the feature tensors).
        scaling_factor: blend strength; 0 returns *args* itself unchanged.
        time_range: ``(start_s, end_s)`` in seconds. A negative end means
            "until the end of the clip", so ``(0, -1)`` is the whole clip.
        target_vals: target tensors in the order creak, cpps, h1h2, h1a3;
            assumed to have at least *audio_len* entries (as built in
            ``combined_vq``).
        sr: sample rate used to turn seconds into frames (320 samples/frame).
        device: device for scalar features that are expanded to tensors.

    Returns:
        A new argument list. The four features are fresh tensors; every
        other entry is taken from *args*.
    """
    if scaling_factor == 0:
        return args

    start_time, end_time = time_range
    start_frame = int(start_time * sr // 320)
    # Previously only the exact (0, -1) tuple meant "to the end"; e.g. (0.5, -1)
    # became a negative frame index and stopped 50 frames before the end.
    end_frame = audio_len if end_time < 0 else min(int(end_time * sr // 320), audio_len)

    # clone(): to_tensor returns an existing tensor as-is, and the in-place
    # update below would otherwise modify tensors owned by the caller's args.
    params = [to_tensor(args[i], audio_len, device).clone() for i in (3, 4, 5, 7)]
    for param, target in zip(params, target_vals):
        param[start_frame:end_frame] += (target[start_frame:end_frame] - param[start_frame:end_frame]) * scaling_factor

    creak, cpps, h1h2, h1a3 = params
    return [args[0], args[1], args[2], creak, cpps, h1h2, args[6], h1a3, args[8]]

# --- Combined VQ workflow ---
def combined_vq(current_args, pitch, pitch_var, audio, sr=16000, device=device):
    """Interactively apply breathy, creaky and nasal voice quality, then convert.

    For each voice quality the user is prompted (via ``input``) for time
    ranges in seconds and an intensity between 0 and 5. Every range with a
    positive intensity is pulled towards that quality's target feature values
    with ``apply_vq``; nasal is applied first, then breathy, then creaky.
    Pitch and pitch variation are then set, the feature averages are printed
    and ``convert`` is run.

    Args:
        current_args: list in ``arg_creator`` layout. It is copied, so the
            caller's list is never modified.
        pitch: value stored at index 6 of the returned list.
        pitch_var: value stored at index 8 of the returned list.
        audio: waveform used only to size the feature tensors (number of
            content frames + 1).
        sr: sample rate used to convert the entered seconds into frames.
        device: torch device for the feature tensors.

    Returns:
        The argument list that was passed to ``convert``.
    """
    # Copy first: when no segment is applied, the pitch assignments below
    # would otherwise overwrite entries of the caller's own list.
    current_args = list(current_args)

    tensor = utils.get_content(cmodel, torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(device))
    tensor_size = tensor.shape[-1] + 1

    def get_segments(vq_name, default_intensity=1):
        """Prompt for the time ranges and per-range intensity of one voice quality."""
        ranges_input = input(f"Enter ranges for {vq_name} (e.g., (0,0.4),(0.5,1.9)) or press Enter for full clip: ").strip()

        if not ranges_input:
            # Sentinel that apply_vq treats as "the whole clip".
            parsed_ranges = [(0, -1)]
        else:
            parsed_ranges = []
            for start, end in re.findall(r'\(\s*([\d.-]+)\s*,\s*([\d.-]+)\s*\)', ranges_input):
                try:
                    parsed_ranges.append((float(start), float(end)))
                except ValueError:
                    # The character class also matches strings like "1.2.3" or "-".
                    print(f"Skipping malformed range ({start}, {end}).")
            if not parsed_ranges:
                print(f"No valid ranges recognised for {vq_name}; it will not be applied.")

        segments = []
        for time_range in parsed_ranges:
            while True:
                intensity_input = input(f"Select intensity for {vq_name} {time_range} (0-5, default {default_intensity}): ").strip()
                if not intensity_input:
                    intensity = default_intensity
                    break
                try:
                    intensity = float(intensity_input)
                    if 0 <= intensity <= 5:
                        break
                    else:
                        print("Intensity must be between 0 and 5.")
                except ValueError:
                    print("Please enter a number between 0 and 5.")
            segments.append((time_range, intensity))

        return segments

    breathy_segments = get_segments("breathiness")
    creaky_segments = get_segments("creakiness")
    nasal_segments = get_segments("nasality")

    # Target values in the order creak, cpps, h1h2, h1a3 (the order apply_vq uses).
    targets = {
        "breathy": [torch.full((tensor_size,), -2., device=device),
                    torch.full((tensor_size,), -1., device=device),
                    torch.full((tensor_size,), 3., device=device),
                    torch.full((tensor_size,), 3., device=device)],
        "creaky": [torch.full((tensor_size,), 3., device=device),
                   torch.full((tensor_size,), -1., device=device),
                   torch.full((tensor_size,), -2., device=device),
                   torch.full((tensor_size,), -2., device=device)],
        "nasal": [torch.full((tensor_size,), 0., device=device),
                  torch.full((tensor_size,), 1., device=device),
                  torch.full((tensor_size,), -3., device=device),
                  torch.full((tensor_size,), 3., device=device)]
    }

    for segments, vq_name in [(nasal_segments, "nasal"), (breathy_segments, "breathy"), (creaky_segments, "creaky")]:
        for time_range, intensity in segments:
            if intensity > 0:
                current_args = apply_vq(current_args, tensor_size, intensity, time_range, targets[vq_name], sr, device)

    # Update pitch & variance
    current_args[6] = pitch
    current_args[8] = pitch_var

    # Print averages; indices 3..8 follow the arg_creator layout.
    stats = ["creak", "cpps", "h1h2", "pitch", "h1a3", "pitch_var"]
    for i, stat in enumerate(stats, start=3):
        val = current_args[i]
        print(f"Current average {stat}: {torch.mean(val).item() if isinstance(val, torch.Tensor) else val}")

    convert(current_args)
    return current_args

# --- Reset args ---
def reset_args(args):
    """Keep the three paths from *args* and reset all six features to 0."""
    source, target, outpath = args[:3]
    return [source, target, outpath, 0, 0, 0, 0, 0, 0]

# --- Timestamp ---
def generate_timestamp(now=None):
    """Return a fixed-width ``YYYYmmdd_HHMMSS_mmm`` timestamp string.

    Args:
        now: optional ``datetime`` to format; defaults to the current local time.

    The millisecond part is always three digits. The previous ``zfill(2)``
    gave two digits below 100 ms but three from 100 ms on, so the width
    varied and the stamps did not sort lexicographically.
    """
    if now is None:
        now = datetime.now()
    return now.strftime("%Y%m%d_%H%M%S_") + f"{now.microsecond // 1000:03d}"

# --- Convert ---
def convert(args):
    """Convert the source utterance into the target voice with the given features.

    Args:
        args: list in ``arg_creator`` layout: [source wav, target wav, output
            path, creak, cpps, h1h2, pitch, h1a3, pitch_var]. Each of the six
            features is added to a ``(1, 1, frames)`` zero tensor, so it must
            be a scalar or broadcast to that shape.

    Returns:
        The converted waveform as a NumPy array at ``hps.data.sampling_rate``.
        It is also played inline and written to ``args[2]``.

    Raises:
        RuntimeError: if ``net_g.infer`` fails for both tried frame counts.
    """
    print("Converting...")
    sampling_rate = hps.data.sampling_rate

    # Speaker embedding from the (silence-trimmed) target recording.
    wav_tgt, _ = librosa.load(args[1], sr=sampling_rate)
    wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
    g_tgt = torch.from_numpy(smodel.embed_utterance(wav_tgt)).unsqueeze(0).to(device)

    # Content features of the source recording.
    wav_src, _ = librosa.load(args[0], sr=sampling_rate, mono=True)
    wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
    c = utils.get_content(cmodel, wav_src)

    # Try the content length and one frame less, presumably to absorb an
    # off-by-one between the content frames and feature tensors sized
    # elsewhere (e.g. combined_vq's frames + 1).
    tgt_audio = None
    last_error = None
    with torch.no_grad():
        for dim in (c.shape[-1], c.shape[-1] - 1):
            basic_tensor = torch.zeros((1, 1, dim), device=device)
            try:
                tgt_audio = net_g.infer(c, g=g_tgt,
                                        creaks=basic_tensor + args[3],
                                        cpps=basic_tensor + args[4],
                                        h1h2s=basic_tensor + args[5],
                                        pitches=basic_tensor + args[6],
                                        h1a3s=basic_tensor + args[7],
                                        pitch_vars=basic_tensor + args[8])
                break
            except Exception as err:  # retry with the other length, but remember why
                last_error = err

    if tgt_audio is None:
        # Previously the errors were swallowed and the code below failed with
        # an UnboundLocalError that hid the real cause.
        raise RuntimeError("FreeVC inference failed for all tried frame counts") from last_error

    tgt_audio = tgt_audio[0][0].cpu().detach().numpy()
    timestamp = generate_timestamp()
    print(timestamp)
    ipd.display(ipd.Audio(tgt_audio, rate=sampling_rate))
    write(args[2], sampling_rate, tgt_audio)
    return tgt_audio

# --- Create VQ tensor from segments ---
def create_vq_tensor(wav_len, vq_segments, sr):
    """Build a piecewise-constant ``(1, 1, wav_len // 320)`` float32 feature track.

    Args:
        wav_len: signal length in samples.
        vq_segments: iterable of ``(start_s, end_s, value)``; frames in
            ``[start_s * sr // 320, end_s * sr // 320)`` are set to *value*,
            later segments overwriting earlier ones. Other frames stay 0.
        sr: sample rate used to convert seconds into frames.
    """
    hop = 320
    track = torch.zeros((1, 1, wav_len // hop), dtype=torch.float32)
    for seg_start, seg_end, value in vq_segments:
        first = int(seg_start * sr // hop)
        last = int(seg_end * sr // hop)
        track[0, 0, first:last] = value
    return track

# --- JSON to TextGrid ---
def json_to_textgrid(json_path, textgrid_path):
    """Write a word-level TextGrid from a WhisperX JSON transcript.

    Args:
        json_path: WhisperX JSON with ``segments``, each holding ``words``
            that carry ``word``, ``start`` and ``end``.
        textgrid_path: destination of the ``long_textgrid`` file.

    Returns:
        The list of timed word dicts used for the tier.

    Raises:
        ValueError: if the transcript contains no segments.
    """
    # WhisperX writes UTF-8; don't depend on the platform's locale encoding.
    with open(json_path, encoding='utf-8') as f:
        data = json.load(f)

    segments = data['segments']
    if not segments:
        raise ValueError(f"No speech segments found in {json_path}")
    start_speech, end_speech = segments[0]['start'], segments[-1]['end']

    # WhisperX leaves words it cannot align (e.g. numerals) without
    # 'start'/'end'; keep only timed words so the lookups below cannot fail.
    timestamps = [word for seg in segments for word in seg['words']
                  if 'start' in word and 'end' in word]

    tg_obj = tg.Textgrid()
    tier = tg.IntervalTier('word', [], start_speech, end_speech)
    for i, t in enumerate(timestamps):
        # After the first word, each interval starts 10 ms after the previous
        # word's end rather than at its own 'start', presumably so adjacent
        # intervals never overlap.
        start = timestamps[i-1]['end']+0.01 if i>0 else t['start']
        end = t['end']
        tier.insertEntry(Interval(start, end, t['word']))
    tg_obj.addTier(tier)
    tg_obj.save(textgrid_path, format='long_textgrid', includeBlankSpaces=False)
    return timestamps

# --- Initialize ---
hps = utils.get_hparams_from_file('configs/freevc.json')
# Move the synthesizer to the same device as its inputs: convert() puts the
# content and speaker tensors on `device`, and a CPU-resident net_g fails
# there whenever CUDA is available.
net_g = SynthesizerTrn(
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model,
).to(device)
utils.load_checkpoint(freevc_chpt_path, net_g, optimizer=None, strict=True)
# Inference only: switch training-time layers (e.g. dropout) to eval behaviour.
net_g.eval()
cmodel = get_cmodel()
smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt', device=device)


## Set the arguments
Here we set the following arguments for the creaky voice conversion:

1. **source_path** indicates the audio file of which we want the linguistic content.
2. **target_path** indicates the audio file containing speech of the target speaker.
3. **outpath** is the location where the converted audio will be saved.
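4. **average_&lt;feature&gt;** (`average_creak`, `average_cpps`, `average_h1h2`, `average_pitch`, `average_h1a3`, `average_pitch_var`) is the initial value of that feature over the complete utterance, which is supplied to the model.

We also create the output folder specified in the arguments.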


## Instructions
1. Set the source path to the audio file you want to convert (uncomment it if it is commented out).
2. Select your desired feature values. I recommend starting with all zeros, except for pitch where I recommend starting at -1.
3. Listen to the original audio in the cell below the manipulation cell.
4. Perform the conversion.
5. Decide if you like the prosody.
6. If not, change the feature values in the coarse-grained editing cell.



## Manipulation cell

In [None]:
source_path = "/shared/lameris/spoken_stereoset/speech/gender/02273_en-US-ChristopherNeural.wav"


target_path = "/home/lameris/CreakVC/speech_continuation/concatenated_christopher_44s.wav"
# The converted file is saved in ./data/experiment under the source's file name.
outpath = f"./data/experiment/{source_path.split('/')[-1]}"

# Initial utterance-level feature values supplied to the model.
average_creak = 0
average_cpps = 0
average_h1h2 = 0
average_pitch = -1.5
average_h1a3 = 0
average_pitch_var = 0

sr = 16000
args = arg_creator(source_path, target_path, outpath, average_creak, average_cpps, average_h1h2, average_pitch, average_h1a3, average_pitch_var)
# os.path.dirname instead of manual '/'-splitting; fall back to '.' when the
# path has no directory part (os.makedirs('') would raise).
os.makedirs(os.path.dirname(args[2]) or '.', exist_ok=True)

# Play the source at the rate stored in the file (sf.read returns it) rather
# than a hard-coded 16 kHz, which plays other-rate files at the wrong speed/pitch.
src_audio, src_sr = sf.read(source_path)
# ipd.Audio expects (channels, samples) for multi-channel data; .T is a no-op for mono.
ipd.display(ipd.Audio(src_audio.T, rate=src_sr))


# Perform the conversion

In [None]:
# Convert with the manipulation-cell settings; convert() also plays the result
# and writes it to outpath (args[2]).
converted_audio = convert(args=args)

## Coarse-grained editing for finding pitch and pitch variation


In [None]:
# Coarse-grained pass: the voice-quality features stay at 0 while pitch and
# pitch variation are tuned.
average_creak = average_cpps = average_h1h2 = average_h1a3 = 0
average_pitch = -1.5
average_pitch_var = 3

orig_args = arg_creator(
    source=source_path, target=target_path, outpath=outpath,
    creak=average_creak, cpps=average_cpps, h1h2=average_h1h2,
    pitch=average_pitch, h1a3=average_h1a3, pitch_var=average_pitch_var,
)
converted_audio = convert(orig_args)

## Run WhisperX
* We run speech recognition (WhisperX) to transcribe the utterance; note that the command below transcribes the source file, not the converted output.
* We convert the word timestamps from JSON to a TextGrid in order to prepare them for CreaPy.

In [None]:
#run whisper
import json
import os
from praatio import textgrid as tg
from praatio.utilities.constants import Interval

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

!whisperx "/shared/lameris/spoken_stereoset/speech/gender/02273_en-US-ChristopherNeural.wav" --model distil-medium.en --output_dir data/experiment --language en 

#read json file with the start and end times of the words
json_path = "data/experiment/" + args[0].split('/')[-1].replace(".wav", ".json")
tg_path = "data/experiment/" + source_path.split('/')[-1][:-4] + '.TextGrid'
timestamps = json_to_textgrid(json_path, tg_path)


## Run CreaPy
This enables us to quantify and visualize the creak probability.

In [None]:
# Analyse the converted file (convert() saved it at this path) using the word
# TextGrid built from the WhisperX transcript.
converted_path = f"./data/experiment/{source_path.split('/')[-1]}"
X_test, y_pred, sr = creapy.process_file(textgrid_path=tg_path, audio_path=converted_path)

# 20-point moving average (same-length output) to smooth both curves.
window = np.ones(20) / 20
X_test['h1h2'] = np.nan_to_num(X_test['h1h2'])
X_test['h1h2'] = np.convolve(X_test['h1h2'], window, mode='same')
y_pred_smoothed = np.convolve(y_pred, window, mode='same')

## Plot the audio
Read the segment start and end times from this plot for the fine-grained editing.

In [None]:
# Plot creapy's features (X_test) with the smoothed creak probability; word
# labels come from the WhisperX timestamps returned by json_to_textgrid.
fig = creapy.plot(X_test, y_pred_smoothed, sr, words=timestamps)

## Fine-grained editing

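Edit the features according to the following syntax. `combined_vq` asks, for breathiness, creakiness and nasality in turn, for time ranges in seconds written as `(start,end),(start,end)` — e.g. `(0,0.4),(0.5,1.9)` — or press Enter for the full clip. It then asks for an intensity between 0 and 5 for each range (press Enter to keep the default of 1).


You only need to change the times and feature values!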

In [None]:
# Fine-grained pass with pitch -1.5 and pitch variation 2; combined_vq then
# prompts for breathiness / creakiness / nasality ranges and intensities.
combined = combined_vq(args, pitch=-1.5, pitch_var=2, audio=converted_audio)


In [None]:
# Same fine-grained pass, but with pitch variation 0.
combined = combined_vq(args, pitch=-1.5, pitch_var=0, audio=converted_audio)


In [None]:
# Rebuild the baseline: neutral voice-quality features, pitch -1.5 and
# pitch variation 3, then convert again.
average_creak, average_cpps, average_h1h2, average_h1a3 = 0, 0, 0, 0
average_pitch, average_pitch_var = -1.5, 3

orig_args = arg_creator(source_path, target_path, outpath,
                        creak=average_creak, cpps=average_cpps, h1h2=average_h1h2,
                        pitch=average_pitch, h1a3=average_h1a3, pitch_var=average_pitch_var)
converted_audio = convert(orig_args)

In [None]:
# Preset feature values (creak, cpps, h1h2, pitch, h1a3, pitch_var) for
# auditioning several voice qualities with the same source and target.
# NOTE(review): every preset writes to the same `outpath`, so only the last
# conversion (nasal) stays on disk; all results are kept in memory below.
breathy = arg_creator(source_path, target_path, outpath, -1, -1, 2, -1.5, 2, 1)
creaky = arg_creator(source_path, target_path, outpath, 2, -.5, -1, -2, -1, -2)
high_pitch_var = arg_creator(source_path, target_path, outpath, -3, 0, -.5, -1.5, -1, 4)
tense = arg_creator(source_path, target_path, outpath, -1, 1, 2, -1.5, -2, 0)
nasal = arg_creator(source_path, target_path, outpath, -1, 1, -3, -1.5, 3, 0)

print('--------------Breathy----------------')
converted_breathy_audio = convert(breathy)

print('--------------Creaky----------------')
converted_creaky_audio = convert(creaky)

print('--------- Pitch-variation----------------')
converted_high_pitch_var_audio = convert(high_pitch_var)

print('--------------Tense----------------')
converted_tense_audio = convert(tense)

print('--------------Nasal----------------')
converted_nasal_audio = convert(nasal)
