In [1]:
from semanticodec import SemantiCodec
import torch
from torch import nn

# Shared codec instance used by the cells below; this token-rate / vocabulary
# combination corresponds to the 1.40 kbps setting.
semanticodec = SemantiCodec(
    token_rate=100,
    semantic_vocab_size=32768,
)  # 1.40 kbps

  from .autonotebook import tqdm as notebook_tqdm


🚀 Loading SemantiCodec encoder
✅ Encoder loaded
🚀 Loading SemantiCodec decoder
DiffusionWrapper has 75.04 M params.




✅ Decoder loaded


In [2]:
import soundfile as sf

# Input recording; the codec accepts audio of arbitrary length.
filepath = "/mnt/users/hccl.local/jkzhao/projects/s3prl/test.wav"

# Round trip: audio file -> codec tokens -> reconstructed waveform.
tokens = semanticodec.encode(filepath)
waveform = semanticodec.decode(tokens)

# Write the reconstruction to disk at 16 kHz.
# waveform[0, 0] picks the first entry along the two leading axes
# (presumably batch and channel -- TODO confirm against the decoder output).
sf.write("output.wav", waveform[0, 0], samplerate=16000)

DDIM Sampler: 100%|██████████| 50/50 [00:06<00:00,  7.76it/s]


In [None]:
# Feed the raw encoder a random batch: 16 items, 2048 frames, 128 feature bins.
# The tensor is sampled on the CPU and then moved to the GPU.
mel_shape = (16, 2048, 128)
mel = torch.rand(mel_shape).to("cuda")
tokens = semanticodec.encoder(mel)

In [3]:
from semanticodec.utils import extract_kaldi_fbank_feature
# Audio / segmentation constants used by the preprocessing code below.
SAMPLE_RATE = 16_000              # samples per second expected by the codec front end
SEGMENT_DURATION = 10.24          # seconds of audio per encoder segment
MEL_TARGET_LENGTH = 1024          # mel frames produced for one segment
AUDIOMAE_PATCH_DURATION = 0.16    # seconds; durations are rounded up to this step
SEGMENT_OVERLAP_RATIO = 0.0625    # not referenced in this notebook -- presumably for overlapped decoding

class CodecWrapper(nn.Module):
    """``nn.Module`` wrapper that maps raw waveforms to SemantiCodec latents.

    The wrapper owns a ``SemantiCodec`` instance (``self.model``). Each input
    waveform is padded to whole segments, converted to a mel spectrogram,
    encoded to tokens and finally mapped to the quantized features of those
    tokens.
    """

    def __init__(self):
        """Create the wrapped codec (token rate 100, semantic vocabulary 32768)."""
        super().__init__()
        self.model = SemantiCodec(token_rate=100, semantic_vocab_size=32768)  # 1.40 kbps

    @staticmethod
    def _pad_to_segment_multiple(waveform, segment_sample_length):
        """Zero-pad ``waveform`` along dim 1 to a whole number of segments.

        At least one segment is always produced, so an empty waveform becomes
        one segment of silence. A waveform whose length is already a non-zero
        multiple of ``segment_sample_length`` is returned unchanged (the
        previous always-true length check appended a full extra segment of
        silence in that case).
        """
        num_samples = waveform.shape[1]
        # Ceiling division without floats; clamp to one segment for empty input.
        num_segments = max(1, -(-num_samples // segment_sample_length))
        pad = num_segments * segment_sample_length - num_samples
        if pad == 0:
            return waveform
        # new_zeros keeps the padding on the waveform's device and dtype.
        silence = waveform.new_zeros(waveform.shape[0], pad)
        return torch.cat([waveform, silence], dim=1)

    @staticmethod
    def _stack_latents(latents):
        """Concatenate per-waveform latents on dim 0.

        Latents with fewer entries on the second-to-last axis are zero-padded
        at the end of that axis so that all of them share the longest length.
        Assumes each latent is shaped (1, steps, dim) -- TODO confirm against
        ``token_to_quantized_feature``.
        """
        longest = max(latent.shape[-2] for latent in latents)
        padded = [
            # Pad tuple is (last-dim left/right, second-to-last-dim left/right).
            nn.functional.pad(latent, (0, 0, 0, longest - latent.shape[-2]))
            for latent in latents
        ]
        return torch.cat(padded, dim=0)

    def preprocess(self, waveform):
        """Turn one waveform into a segment-aligned mel spectrogram.

        Args:
            waveform: Tensor shaped (channels, samples), expected to be sampled
                at ``SAMPLE_RATE`` Hz (this method does not resample). Only the
                first channel is used.

        Returns:
            A tuple ``(mel, target_token_len)``. ``mel`` has 128 bins on its
            last axis and a multiple of ``MEL_TARGET_LENGTH`` frames on its
            second-to-last axis (both are asserted). ``target_token_len`` is
            the theoretical, possibly fractional, token count for the audio
            duration rounded up to ``AUDIOMAE_PATCH_DURATION``.
        """
        # Multi-channel input: keep only the first channel.
        if waveform.shape[0] > 1:
            waveform = waveform[0:1]

        # Duration rounded up to the next multiple of the patch duration.
        # NOTE(review): a duration that is already an exact multiple still
        # gains one extra patch here; left unchanged because float modulo makes
        # "exact" ambiguous -- confirm whether that is intended.
        original_duration = waveform.shape[1] / SAMPLE_RATE
        original_duration = original_duration + (
            AUDIOMAE_PATCH_DURATION - original_duration % AUDIOMAE_PATCH_DURATION
        )
        # Theoretical token length for that duration.
        target_token_len = (
            8 * original_duration / AUDIOMAE_PATCH_DURATION / self.model.stack_factor_K
        )

        # Pad the audio to whole segments so it can be split evenly.
        segment_sample_length = int(SAMPLE_RATE * SEGMENT_DURATION)
        waveform = self._pad_to_segment_multiple(waveform, segment_sample_length)

        # MEL_TARGET_LENGTH frames per segment.
        mel_target_length = MEL_TARGET_LENGTH * (
            waveform.shape[1] // segment_sample_length
        )
        mel = extract_kaldi_fbank_feature(
            waveform, SAMPLE_RATE, target_length=mel_target_length
        )["ta_kaldi_fbank"].unsqueeze(0)
        # Only drops dim 1 if it has size 1; the original author marked this
        # call as having no effect for the shapes seen here.
        mel = mel.squeeze(1)
        assert mel.shape[-1] == 128 and mel.shape[-2] % MEL_TARGET_LENGTH == 0
        return mel, target_token_len

    def forward(self, wavs):
        """Encode every waveform in ``wavs`` to quantized latent features.

        Args:
            wavs: Non-empty iterable of 1-D waveform tensors (each gets a
                leading channel axis before preprocessing).

        Returns:
            The latents of all waveforms concatenated along dim 0, zero-padded
            along the second-to-last axis to the longest one. For a single
            waveform this has the same values and shape as that waveform's
            latent. Latents are not trimmed to ``target_token_len``, so they
            cover the padded segments as well.

        Raises:
            ValueError: If ``wavs`` yields no waveform.
        """
        latents = []
        # Each waveform is encoded on its own; the previous implementation
        # returned inside this loop and therefore only encoded the first one.
        for wav in wavs:
            mel, _ = self.preprocess(wav.unsqueeze(0))
            tokens = self.model.encoder(mel.to(self.model.device))
            latents.append(self.model.encoder.token_to_quantized_feature(tokens))
        if not latents:
            raise ValueError("wavs must contain at least one waveform")
        return self._stack_latents(latents)

In [8]:
import torchaudio

# Alternative input (disabled): load the real recording and resample it to 16 kHz.
# wav, sr = torchaudio.load(filepath)
# wav = torchaudio.functional.resample(wav, sr, 16000)
# wavs = [wav.squeeze(0).to("cuda")]

# Synthetic input: a single 16000-sample (one second at 16 kHz) Gaussian-noise
# waveform, sampled on the CPU and then moved to the GPU.
wavs = [torch.randn(16_000).to(device="cuda")]

In [None]:
# Build the wrapper; its constructor creates a SemantiCodec instance, which
# prints the encoder/decoder loading messages shown in the cell output.
model = CodecWrapper()

🚀 Loading SemantiCodec encoder
✅ Encoder loaded
🚀 Loading SemantiCodec decoder
DiffusionWrapper has 75.04 M params.




✅ Decoder loaded


In [11]:
# Put the module in evaluation mode and run the forward pass without autograd.
model.eval()
with torch.no_grad():
    # Keep the result: it was previously discarded, so the cell produced
    # nothing that could be inspected.
    latent = model(wavs)

# Display the shape of the encoded latents as the cell output.
latent.shape