## Testing outputs for FACodec

In [33]:
import sys
import librosa
import torch
import soundfile as sf
import IPython.display as IPD
sys.path.append("/home/yash7/Amphion/")

from models.codec.ns3_codec import FACodecEncoderV2, FACodecDecoderV2

In [9]:
enc = FACodecEncoderV2(
    ngf=32,
    up_ratios=[2, 4, 5, 5],
    out_channels=256,
)
dec = FACodecDecoderV2(
    in_channels=256,
    upsample_initial_channel=1024,
    ngf=32,
    up_ratios=[5, 5, 4, 2],
    vq_num_q_c=2,
    vq_num_q_p=1,
    vq_num_q_r=3,
    vq_dim=256,
    codebook_dim=8,
    codebook_size_prosody=10,
    codebook_size_content=10,
    codebook_size_residual=10,
    use_gr_x_timbre=True,
    use_gr_residual_f0=True,
    use_gr_residual_phone=True,
)

  WeightNorm.apply(module, name, dim)


In [10]:
enc.eval()
dec.eval()

FACodecDecoderV2(
  (quantizer): ModuleList(
    (0): ResidualVQ(
      (layers): ModuleList(
        (0): FactorizedVectorQuantize(
          (in_proj): Linear(in_features=256, out_features=8, bias=True)
          (out_proj): Linear(in_features=8, out_features=256, bias=True)
          (_codebook): Embedding(1024, 8)
        )
      )
    )
    (1): ResidualVQ(
      (layers): ModuleList(
        (0-1): 2 x FactorizedVectorQuantize(
          (in_proj): Linear(in_features=256, out_features=8, bias=True)
          (out_proj): Linear(in_features=8, out_features=256, bias=True)
          (_codebook): Embedding(1024, 8)
        )
      )
    )
    (2): ResidualVQ(
      (layers): ModuleList(
        (0-2): 3 x FactorizedVectorQuantize(
          (in_proj): Linear(in_features=256, out_features=8, bias=True)
          (out_proj): Linear(in_features=8, out_features=256, bias=True)
          (_codebook): Embedding(1024, 8)
        )
      )
    )
  )
  (model): Sequential(
    (0): Conv1d(256

In [42]:
def process(audio):
    audio = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(0)
    return audio

testWavPath = "/home/yash7/Amphion/custom/data/wTIMIT_Normal_s019/008.wav"
testWavPath2 = "/home/yash7/Amphion/custom/data/wTIMIT_Normal_s000/009.wav"

testWav1, sr1 = librosa.load(testWavPath, sr=16000)
testWav2, sr2 = librosa.load(testWavPath2, sr=16000)

testWav1 = process(testWav1)
testWav2 = process(testWav2)

print("testWav1 shape: ", testWav1.shape)
print("testWav2 shape: ", testWav2.shape)

testWav1 shape:  torch.Size([1, 1, 80904])
testWav2 shape:  torch.Size([1, 1, 52768])


In [44]:
with torch.no_grad():
    encOut1 = enc(testWav1)
    prosodyEnc1 = enc.get_prosody_feature(testWav1)
    print("encOut1 shape:\t\t\t", encOut1.shape)
    print("prosodyEnc1 shape:\t\t", prosodyEnc1.shape)
    vqPostEmb, vqId, _, quantized1, spkEmb1 = dec(encOut1, prosodyEnc1, eval_vq=False, vq=True)
    print("vqPostEmb shape:\t\t", vqPostEmb.shape)
    
    print("VQ ID shape:\t\t\t", vqId.shape)
    
    prosodyCode = vqId[:1]
    print("Prosody code shape:\t\t", prosodyCode.shape)
    
    contentCode = vqId[1:3]
    print("Content code shape:\t\t", contentCode.shape)
    
    residualCode = vqId[3:]
    print("Residual code shape:\t\t", residualCode.shape)
    
    print("Speaker embedding shape:\t", spkEmb1.shape)
    
    reconWav1 = dec.inference(
        x = vqPostEmb,
        speaker_embedding=spkEmb1
    )
    print("Reconstructed wav shape:\t", reconWav1.shape)
    sf.write('recon_s019_008.wav', reconWav1[0][0].cpu().numpy(), 16000)
IPD.display(IPD.Audio(testWav2[0][0], rate=16000))
IPD.display(IPD.Audio(reconWav1[0][0], rate=16000))

encOut1 shape:			 torch.Size([1, 256, 404])
prosodyEnc1 shape:		 torch.Size([1, 20, 404])
vqPostEmb shape:		 torch.Size([1, 256, 404])
VQ ID shape:			 torch.Size([6, 1, 404])
Prosody code shape:		 torch.Size([1, 1, 404])
Content code shape:		 torch.Size([2, 1, 404])
Residual code shape:		 torch.Size([3, 1, 404])
Speaker embedding shape:	 torch.Size([1, 256])
Reconstructed wav shape:	 torch.Size([1, 1, 80800])


|TensorName|Shape1|Shape2|
|---|---|---|
|wav shape:|torch.Size([1, 1, 80800])|torch.Size([1, 1, 60000])|
|encOut1 shape:|torch.Size([1, 256, 404])|torch.Size([1, 256, 300])|
|prosody enc 1 shape:|torch.Size([1, 20, 404])|torch.Size([1, 20, 300])|
|vqPostEmb shape:|torch.Size([1, 256, 404])|torch.Size([1, 256, 300])|
|VQ ID shape:|torch.Size([6, 1, 404])|torch.Size([6, 1, 300])|
|Prosody code shape:|torch.Size([1, 1, 404])|torch.Size([1, 1, 300])|
|Content code shape:|torch.Size([2, 1, 404])|torch.Size([2, 1, 300])|
|Residual code shape:|torch.Size([3, 1, 404])|torch.Size([3, 1, 300])|
|Speaker embedding shape:|torch.Size([1, 256])|torch.Size([1, 256])|
|Reconstructed wav shape:|torch.Size([1, 1, 80800])|torch.Size([1, 1, 60000])|

# Shape Observations:
|Name|Shape|Notes|
|---|---|---|
|Wav shape|(1, 1, N)|N  -> Number of samples|
|EncOut1 shape|(1, 256, N/200)|N/200 -> Number of frames, 256 -> Number of channels, 200 -> Hop Size|
|Prosody enc 1 shape|(1, 20, N/200)|N/200 -> Number of frames, 20 -> Prosody Encoding Dimension|
|vqPostEmb shape|(1, 256, N/200)|
|vqId shape|(6, 1, N/200)|6 -> Number of quantizers (prosody + content + residual)|
|Prosody code shape|(1, 1, N/200)|1 quantizer for prosody|
|Content code shape|(2, 1, N/200)|2 quantizers for content|
|Residual code shape|(3, 1, N/200)|3 quantizers for residual (acoustic details, timbre, etc)|
|Speaker Embedding shape|(1, 256)|256 -> Number of dimensions for speaker embeddings|
|Reconstructed wav shape|(1, 1, N')|Roughly the same as original wav, N' = (N / 200) * 200|

* 200 -> Hop Size
* Sample Rate = 16kHz, Hop Size = 200 => 1 frame corresponds to (200 / 16000 * 1000) = 12.5ms
* Number of quantizers (6 = 1 + 2 + 3) was a design choice. 
* Codebook size (256 here) is a design choice, paper specified 1024
* Note, audio may need to be padded to an appropriate lengths otherwise the scripts may crash

## Extracting Content Units (Ideas):
* Use of data augmentations that do not affect content and using contrastive loss on the the samples. Some possible augmentations might include:
    * Loudness changes
    * Pitch changes
    * Addition of random White noise (maybe trying Red noise)
    * Changing Audio speed
    * Combinations of these
* Use of transcripts, audio samples and audios generated using TTS engines