In [1]:
import sys
import os
import math
import random   
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio

sys.path.append('/home/yutong/Amphion')

from models.codec.ns3_codec import FACodecEncoder, FACodecDecoder
from huggingface_hub import hf_hub_download

# Add the Controlled_AC folder (parent of FACodec_AC) to sys.path so we can import our modules.
# sys.path.append('/home/yurii/Projects/AC/Controlled_AC')

# Import facoder encoder/decoder and the diffusion model from FACodec_AC.
# (Make sure the module names and classes match what is in your repository.)
from FACodec_AC.models import DiffusionTransformerModel
from FACodec_AC.config import Config, ASRConfig
from FACodec_AC.utils import get_zc1_from_indx, get_phone_forced_alignment, interpolate_alignment, pad_token_sequence
from IPython.display import Audio

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

device='cuda'
#CUDA_LAUNCH_BLOCKING=0

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
fa_encoder = FACodecEncoder(
    ngf=32,
    up_ratios=[2, 4, 5, 5],
    out_channels=256
)
fa_decoder = FACodecDecoder(
    in_channels=256,
    upsample_initial_channel=1024,
    ngf=32,
    up_ratios=[5, 5, 4, 2],
    vq_num_q_c=2,
    vq_num_q_p=1,
    vq_num_q_r=3,
    vq_dim=256,
    codebook_dim=8,
    codebook_size_prosody=10,
    codebook_size_content=10,
    codebook_size_residual=10,
    use_gr_x_timbre=True,
    use_gr_residual_f0=True,
    use_gr_residual_phone=True,
)


encoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin")
decoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin")

fa_encoder.load_state_dict(torch.load(encoder_ckpt))
fa_decoder.load_state_dict(torch.load(decoder_ckpt))

fa_encoder.eval()
fa_decoder.eval()


if torch.cuda.is_available():
  fa_encoder = fa_encoder.cuda()
  fa_decoder = fa_decoder.cuda()

  WeightNorm.apply(module, name, dim)


In [3]:
# Load an example audio (adjust the file path as needed).
#wav_filepath = '/home/yurii/Projects/AC/ljspeech/LJSpeech-1.1/wavs/LJ001-0009.wav'
# file_name = 'bwc_0001'
#file_name = 'aba_0001'
audio_folder = './'
# wav_filepath = audio_folder + file_name + '.wav'
wav_filepath = 'arctic_b0504.wav'
file_name = 'arctic_b0504'
wav_waveform, wav_sr = torchaudio.load(wav_filepath)
# Resample to 16000 Hz if needed.
if wav_sr != 16000:
    resample = torchaudio.transforms.Resample(orig_freq=wav_sr, new_freq=16000)
    wav_waveform = resample(wav_waveform)
    
wav_waveform = wav_waveform.to(device)

In [4]:
# Run facoder encoder and decoder to obtain latent representations.
with torch.no_grad():
    h_input = fa_encoder(wav_waveform[None, :, :])
    vq_post_emb, vq_id, _, quantized_arr, spk_embs = fa_decoder(h_input, eval_vq=False, vq=True)

zc1_indx = vq_id[1]

In [5]:
from transformers import (
	Wav2Vec2FeatureExtractor,
	Wav2Vec2CTCTokenizer,
	Wav2Vec2Processor,
	Wav2Vec2ForCTC,
	AutoTokenizer,
	T5ForConditionalGeneration,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1) Load phoneme‐CTC Wav2Vec2
fe  = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
tok = Wav2Vec2CTCTokenizer .from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft", use_fast=False)
proc= Wav2Vec2Processor   (feature_extractor=fe, tokenizer=tok)
model= Wav2Vec2ForCTC     .from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft").to(device)
model.eval()
target_sr = fe.sampling_rate # 16000

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'Wav2Vec2PhonemeCTCTokenizer'. 
The class this function is called from is 'Wav2Vec2CTCTokenizer'.


In [6]:
transcript_dict = {file_name: 'The life there was helthiful and athletic but too juvenile'}
# transcript_dict = {wav_filepath: 'author of the danger trail philip steeles and etcetera'}

In [7]:
from nemo_text_processing.text_normalization.normalize import Normalizer
import re
pipeline = {
    "normaliser": Normalizer(lang="en", input_case="cased", deterministic=True, post_process=True),
    "regex": re.compile(r"[^a-z' ]"),
    "LANG_TAG": "<eng-us>: ",
    "dev_g2p": "cuda" if torch.cuda.is_available() else "cpu",
    "g2p_tok": AutoTokenizer.from_pretrained("google/byt5-small"),
    "g2p_net": T5ForConditionalGeneration.from_pretrained("charsiu/g2p_multilingual_byT5_tiny_16_layers_100").to(device).eval()
}

preidcted_ids, _, frames_score, _ = get_phone_forced_alignment(embedding_path=file_name+'.pt', 
                                 audio_folder=audio_folder, 
                                 transcript_metadata=transcript_dict, 
                                 device=device, 
                                 model=model, 
                                 proc=proc, 
                                 target_sr=target_sr,
                                 pipeline=pipeline, 
                                 inference=True)
num_zeros = len(zc1_indx[0])
interpolated_phone_ids = interpolate_alignment(preidcted_ids, num_zeros)
interpolated_predicted_ids, pad_mask = pad_token_sequence(interpolated_phone_ids, Config.max_seq_len, ASRConfig.PAD_ID)

# Remove padding
interpolated_predicted_ids = interpolated_predicted_ids[~pad_mask]

 NeMo-text-processing :: INFO     :: Creating ClassifyFst grammars.


In [8]:
from torchaudio.functional import merge_tokens 

spans = merge_tokens(preidcted_ids[0].cpu(), frames_score[0].cpu()) 
token_conf = torch.sigmoid(torch.tensor([s.score for s in spans]))


In [9]:
import torch.nn.functional as F


readable_spans = []
for span in spans:
    # 1) decode the token ID to phoneme symbol
    token = proc.tokenizer.convert_ids_to_tokens(span.token)  # <- better

    # 2) turn logit score into probability
    prob = float(F.sigmoid(torch.tensor(span.score)))

    readable_spans.append({
        "token": token,
        "start_frame": span.start,
        "end_frame": span.end,
        "probability": prob
    })

# Optional: print them nicely
for span in readable_spans:
    print(f"{span['token']:>4}  frames {span['start_frame']:>3}-{span['end_frame']:<3}  prob={span['probability']:.2f}")


  tʰ  frames   7-8    prob=0.00
  eɪ  frames   9-10   prob=0.00
   l  frames  16-17   prob=0.50
   ɪ  frames  19-20   prob=0.00
   f  frames  34-35   prob=0.49
   ɛ  frames  43-44   prob=0.00
   t  frames  48-49   prob=0.01
   ɛ  frames  51-52   prob=0.42
   r  frames  56-57   prob=0.48
   ɛ  frames  57-58   prob=0.00
   w  frames  60-61   prob=0.43
   a  frames  63-64   prob=0.07
   s  frames  73-74   prob=0.03
   h  frames  78-80   prob=0.48
   ɛ  frames  82-83   prob=0.44
   l  frames  86-87   prob=0.49
   θ  frames  91-92   prob=0.01
   ɪ  frames  93-94   prob=0.07
   f  frames 101-102  prob=0.50
   ʊ  frames 103-104  prob=0.19
   l  frames 108-110  prob=0.44
   a  frames 119-120  prob=0.04
   n  frames 123-124  prob=0.49
   d  frames 124-125  prob=0.20
   a  frames 130-131  prob=0.08
   θ  frames 135-136  prob=0.00
   l  frames 140-142  prob=0.48
   ɛ  frames 142-143  prob=0.16
   t  frames 148-149  prob=0.49
   i  frames 149-150  prob=0.26
   k  frames 156-157  prob=0.49
   b  fr

In [10]:
import itertools

# Get the pad token id.
pad_token_id = proc.tokenizer.pad_token_id

# --- Clean transcript ---
# Convert to list (squeezing if needed) and filter out pad tokens.
token_ids = interpolated_predicted_ids.squeeze().tolist()
filtered_ids = [tid for tid in token_ids if tid != pad_token_id]

# Remove consecutive duplicates.
clean_ids = [key for key, _ in itertools.groupby(filtered_ids)]

# Decode the cleaned token ids.
decoded_text = proc.tokenizer.decode(clean_ids, skip_special_tokens=True)
print("Clean transcript:")
print(decoded_text)

# --- Transcript with repetitions and '-' placeholder for pad tokens ---
# Use the original token_ids list (with pads and repetitions preserved).
tokens = proc.tokenizer.convert_ids_to_tokens(token_ids)
# Replace pad token with '-' and keep the other tokens as is.
tokens_with_dash = ['-' if tid == pad_token_id else t for tid, t in zip(token_ids, tokens)]
# Join tokens with spaces.
decoded_text_with_dash = ' '.join(tokens_with_dash)
print("\nTranscript with '-' for pads:")
print(decoded_text_with_dash)

Clean transcript:
tʰeɪlɪfɛtɛrɛwashɛlθɪfʊlandaθlɛtikbut

Transcript with '-' for pads:
- - - - - - - - - - - - tʰ - - eɪ eɪ - - - - - - - - - l l - - - ɪ ɪ - - - - - - - - - - - - - - - - - - - - - - f f - - - - - - - - - - - - - ɛ - - - - - - - t - - - ɛ ɛ - - - - - - - r ɛ ɛ - - - w w - - - a - - - - - - - - - - - - - - - s - - - - - - - h h h - - - ɛ ɛ - - - - - l - - - - - - - θ - - ɪ ɪ - - - - - - - - - - - f - - ʊ ʊ - - - - - - l l l - - - - - - - - - - - - - - - a - - - - - n n d - - - - - - - - a a - - - - - - - θ - - - - - - - l l l ɛ - - - - - - - - t t i i - - - - - - - - - k k - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - b b - - - - - - - - - - - - - - - - - - - - - - - - - - - u u - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - t - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -


In [11]:
pretrained_codebook = fa_decoder.quantizer[1].layers[0].codebook
pretrained_proj_layer = fa_decoder.quantizer[1].layers[0].out_proj

diffusion_model = DiffusionTransformerModel(
    pretrained_codebook=pretrained_codebook,
    pretrained_proj_layer=pretrained_proj_layer,
    std_file_path=os.path.join(Config.zc1_data_dir, 'stats', 'std.pt'),
    vocab_size=Config.VOCAB_SIZE,
    d_model=Config.d_model,
    nhead=Config.nhead,
    num_layers=Config.num_layers,
    d_ff=Config.d_ff,
    dropout=Config.dropout,
    max_seq_len=Config.max_seq_len
)
# Load the diffusion model checkpoint.
diff_ckpt_path = 'logs/checkpoints/model_exp_888.pt'
diffusion_model.load_state_dict(torch.load(diff_ckpt_path, map_location=device))
diffusion_model = diffusion_model.to(device)
diffusion_model.eval()
_

tensor([[[ 12.6560, -11.3779, -11.5954,  ..., -11.1314, -11.1149, -11.3886],
         [ 12.9684, -11.4895, -11.7963,  ..., -11.2638, -11.2872, -11.5055],
         [ 13.1735, -11.6445, -11.9602,  ..., -11.4312, -11.4182, -11.5752],
         ...,
         [ 11.5143, -10.2931, -10.5837,  ..., -10.2132, -10.2212, -10.2716],
         [ 12.8190, -11.3654, -11.7273,  ..., -11.2228, -11.2559, -11.4240],
         [ 13.0255, -11.5682, -11.9646,  ..., -11.4370, -11.3570, -11.6600]]],
       device='cuda:0')

In [None]:
bsz, seq_len = zc1_indx.shape
start, end = 0, seq_len

# create a padding mask: 0 indicates valid tokens, 1 indicates pad tokens
padding_mask = torch.zeros_like(zc1_indx, dtype=torch.bool, device=device)

# Create a mask_positions tensor marking positions from start to end
mask_positions = torch.zeros_like(zc1_indx, dtype=torch.bool, device=device)
mask_positions[:, start:end] = True


# Compute normalized t as fraction of tokens masked (no padding involved)
masked_count = mask_positions.float().sum(dim=1).item()  # [bsz]
t_value = (masked_count / seq_len)
t_value_norm = t_value / 1
t = torch.full((bsz,1), t_value_norm, device=device, dtype=torch.float)

noise_level_value = 20
noise_min = 1.0
noise_max = 35.0

noise_level_value_norm = (noise_level_value - noise_min) \
        / (noise_max - noise_min)
noise_level = torch.full((bsz,1), noise_level_value_norm, device=device, dtype=torch.float)

# Create noise_scaled
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
    
noise_scaled = torch.randn(
    bsz, seq_len, diffusion_model.proj_to_256.out_features, device=device
) * (noise_level_value * diffusion_model.precomputed_std)


noise_max 35.0


In [13]:
prosody_cond = quantized_arr[0]

In [14]:
diffusion_model.eval()
with torch.no_grad():
    logits = diffusion_model(
        x=zc1_indx.to(device),
        padded_phone_ids=interpolated_predicted_ids.to(device).unsqueeze(0),
        prosody_cond=prosody_cond,
        noise_level=noise_level,
        noise_scaled=noise_scaled,
        mask_positions=mask_positions,
        padding_mask=padding_mask
    )
    # For example, take argmax to get predicted tokens.
    predicted_zc1_indx= logits.argmax(dim=-1)

In [15]:
correct, total = (predicted_zc1_indx == zc1_indx).sum().item(), len(zc1_indx[0])
print(f"{correct} out of {total}")

28 out of 389


In [16]:
zc1_original = get_zc1_from_indx(zc1_indx, padding_mask, fa_decoder)
zc1_corrupted = get_zc1_from_indx(zc1_indx, padding_mask, fa_decoder)
# Corrupt zc1
zc1_corrupted[:,:,start:end]+=noise_scaled.reshape(1,256,-1)[:,:,start:end]
zc1_reconstructed = get_zc1_from_indx(predicted_zc1_indx, padding_mask, fa_decoder)

In [17]:
wavs_output = []
for z_c_tmp in [zc1_original, zc1_corrupted, zc1_reconstructed]:
    with torch.no_grad():
        combined_outs = quantized_arr[0] + z_c_tmp + quantized_arr[2]
        #combined_outs = quantized_arr[0]
        wav_reconstructed = fa_decoder.inference(combined_outs, spk_embs)
        wavs_output.append(wav_reconstructed)

In [18]:
for i, label in enumerate(["Original", "Corrupted", "Reconstructed"]):
    print(label)
    waveform = wavs_output[i].detach().squeeze().cpu().numpy()
    display(Audio(waveform, rate=16000))

Original


Corrupted


Reconstructed
