In [1]:
import torch
import torch.nn as nn
# Number of token ids; RNNModel uses it to size the embedding table and the output layer.
VOCAB_SIZE = 4099
# Only used as the default of RNNModel's `max_length` argument, which the model itself never reads.
MAX_LENGTH = 64

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a CPU-only session.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class RNNModel(nn.Module):
    """Recurrent token model: embedding -> LayerNorm -> dropout -> RNN -> LayerNorm -> linear.

    Args:
        vocab_size: number of token ids; sizes the embedding table and the output layer.
        max_length: accepted for interface compatibility only; the model does not use it.
        embed_size: embedding dimension (also the RNN input size).
        hidden_size: RNN hidden size (also the input size of the output layer).
        rnn_type: recurrent layer class; it is called with ``input_size``, ``hidden_size``,
            ``num_layers`` and ``batch_first=True``.
        rnn_layers: number of stacked recurrent layers.
    """

    def __init__(self, vocab_size=VOCAB_SIZE, max_length=MAX_LENGTH, embed_size=256, hidden_size=256, rnn_type=nn.LSTM, rnn_layers=1):
        super(RNNModel, self).__init__()
        # Honour the `vocab_size` argument instead of silently using the global constant.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        nn.init.xavier_uniform_(self.embedding.weight)

        self.embed_ln = nn.LayerNorm(embed_size)
        self.embed_dropout = nn.Dropout(0.1)

        self.rnn = rnn_type(input_size=embed_size,
                            hidden_size=hidden_size,
                            num_layers=rnn_layers,
                            batch_first=True)
        self._init_rnn_parameters()

        self.rnn_ln = nn.LayerNorm(hidden_size)

        self.linear = nn.Linear(hidden_size, vocab_size)
        # `xavier_uniform` (no trailing underscore) is a deprecated alias; use the in-place form.
        nn.init.xavier_uniform_(self.linear.weight)

    def _init_rnn_parameters(self):
        """Initialise the recurrent layer: orthogonal hidden-to-hidden weights,
        Xavier-normal input-to-hidden weights, zero biases (forget-gate slice set to 1 for LSTMs)."""
        is_lstm = isinstance(self.rnn, nn.LSTM)
        with torch.no_grad():
            for name, param in self.rnn.named_parameters():
                if 'weight_hh' in name:  # Recurrent weights
                    nn.init.orthogonal_(param)
                elif 'weight_ih' in name:  # Input weights
                    nn.init.xavier_normal_(param)
                elif 'bias' in name:
                    param.zero_()
                    if is_lstm:
                        # nn.LSTM packs its biases as (input | forget | cell | output), so the
                        # second quarter is the forget gate. The slice is meaningless for other
                        # RNN types, hence the guard. Both bias_ih and bias_hh receive it.
                        n = param.size(0)
                        param[n // 4:n // 2].fill_(1)

    def forward(self, x):
        """Map token ids ``x`` (batch first) to per-position logits over the vocabulary."""
        x = self.embedding(x)
        x = self.embed_ln(x)
        x = self.embed_dropout(x)
        x, _ = self.rnn(x)  # discard the final hidden state; only per-step outputs are used
        x = self.rnn_ln(x)
        return self.linear(x)

In [3]:
# Each checkpoint is used directly as an nn.Module below, i.e. the files hold whole pickled
# models (presumably RNNModel instances, so that class must already be defined), not
# state_dicts. Newer PyTorch releases default `weights_only` to True, which refuses to
# unpickle arbitrary classes, so the full-pickle behaviour is requested explicitly.
# SECURITY: weights_only=False can execute arbitrary code from the file - only load
# checkpoints you trust.
model100 = torch.load('/kaggle/input/medium_rnn/pytorch/default/1/checkpoint_100.pt', map_location=device, weights_only=False)
model200 = torch.load('/kaggle/input/medium_rnn/pytorch/default/1/checkpoint_200.pt', map_location=device, weights_only=False)
model300 = torch.load('/kaggle/input/medium_rnn/pytorch/default/1/checkpoint_300.pt', map_location=device, weights_only=False)

  model100 = torch.load('/kaggle/input/medium_rnn/pytorch/default/1/checkpoint_100.pt', map_location=device)
  model200 = torch.load('/kaggle/input/medium_rnn/pytorch/default/1/checkpoint_200.pt', map_location=device)
  model300 = torch.load('/kaggle/input/medium_rnn/pytorch/default/1/checkpoint_300.pt', map_location=device)


In [4]:
from tqdm import tqdm

# Special (non-audio) token ids; generate() treats every id above 4095 as special.
BOS_TOKEN, EOS_TOKEN, INP_PAD_TOKEN = 4096, 4097, 4098

# Default prompt for generate(): a single BOS token.
sequence = [BOS_TOKEN]

def generate(model, seq=sequence, max_len=600, tmp=1.0, force=False, watch_tail=None):
    """Autoregressively sample tokens from ``model``, continuing the prompt ``seq``.

    Args:
        model: module mapping a ``(1, T)`` tensor of token ids to logits indexed as
            ``[batch, position, token]``. It is switched to eval mode.
        seq: prompt token ids. It is copied, never mutated.
        max_len: target total length including the prompt; at most
            ``max_len - len(seq)`` tokens are sampled.
        tmp: sampling temperature; the logits are divided by it. Must be > 0.
        force: if False, sampling a special token (id >= ``BOS_TOKEN``) stops generation.
            If True, that step is resampled from the non-special ids only, so generation
            always runs for the full length.
        watch_tail: if not None, only the last ``watch_tail`` tokens are fed to the model
            at each step.

    Returns:
        The prompt plus sampled tokens as a list of ints, with the first element
        (the BOS token of the default prompt) dropped.

    Raises:
        ValueError: if ``tmp`` is not strictly positive.
    """
    if tmp <= 0:
        # Dividing logits by 0 yields inf/NaN and a confusing error inside Categorical.
        raise ValueError(f"tmp (temperature) must be positive, got {tmp}")

    device = next(model.parameters()).device
    model.eval()
    generated = list(seq)

    with torch.no_grad():
        for _ in tqdm(range(max_len - len(seq))):
            # Slice the Python list first so only the tokens actually fed to the model
            # are converted to a tensor (same result as slicing the tensor afterwards).
            context = generated if watch_tail is None else generated[-watch_tail:]
            src = torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)

            logits = model(src)[0, -1, :] / tmp
            token = torch.distributions.Categorical(logits=logits).sample().item()

            # Ids from BOS_TOKEN upwards are special tokens, not audio codes.
            if token >= BOS_TOKEN:
                if not force:
                    break
                # Resample with the special ids excluded; using the same bound as the
                # check above guarantees the new token is a regular code.
                token = torch.distributions.Categorical(logits=logits[:BOS_TOKEN]).sample().item()
            generated.append(token)

    return generated[1:]

In [5]:
# Fetch the WavTokenizer sources and make the checkout the working directory.
# NOTE(review): a later cell imports `encoder.utils` and `decoder.pretrained` - presumably
# packages inside this repository, which is why the %cd is needed; confirm.
!git clone https://github.com/jishengpeng/WavTokenizer
%cd WavTokenizer

Cloning into 'WavTokenizer'...
remote: Enumerating objects: 200, done.[K
remote: Counting objects: 100% (116/116), done.[K
remote: Compressing objects: 100% (81/81), done.[K
remote: Total 200 (delta 65), reused 37 (delta 35), pack-reused 84 (from 1)[K
Receiving objects: 100% (200/200), 469.31 KiB | 5.46 MiB/s, done.
Resolving deltas: 100% (83/83), done.
/kaggle/working/WavTokenizer


In [6]:
!wget https://huggingface.co/novateur/WavTokenizer-medium-music-audio-75token/resolve/main/wavtokenizer_medium_music_audio_320_24k_v2.ckpt
!wget https://huggingface.co/novateur/WavTokenizer-medium-music-audio-75token/resolve/main/wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml

--2025-04-21 22:37:16--  https://huggingface.co/novateur/WavTokenizer-medium-music-audio-75token/resolve/main/wavtokenizer_medium_music_audio_320_24k_v2.ckpt
Resolving huggingface.co (huggingface.co)... 3.166.152.65, 3.166.152.105, 3.166.152.44, ...
Connecting to huggingface.co (huggingface.co)|3.166.152.65|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.hf.co/repos/f8/d8/f8d8b97f33126a1e3a4c3ffe2e6af86c26776cfa33aee1294672329b62115562/078d11581aa10cc91572bfcff7ef00b71d8e24c4c359e98f9194a1a9d21ae8a8?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27wavtokenizer_medium_music_audio_320_24k_v2.ckpt%3B+filename%3D%22wavtokenizer_medium_music_audio_320_24k_v2.ckpt%22%3B&Expires=1745278636&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NTI3ODYzNn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zL2Y4L2Q4L2Y4ZDhiOTdmMzMxMjZhMWUzYTRjM2ZmZTJlNmFmODZjMjY3NzZjZmEzM2FlZTEyOTQ2

In [7]:
from encoder.utils import convert_audio
import torchaudio
import torch
from decoder.pretrained import WavTokenizer

# Pretrained WavTokenizer checkpoint and the YAML config describing it.
model_path = "/kaggle/working/WavTokenizer/wavtokenizer_medium_music_audio_320_24k_v2.ckpt"
config_path = "/kaggle/working/WavTokenizer/wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"

# Build the tokenizer from config + checkpoint and move it onto `device` in one expression.
wavtokenizer = WavTokenizer.from_pretrained0802(config_path, model_path).to(device)

  WeightNorm.apply(module, name, dim)
  state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']


In [8]:
def tokens_to_wav(token_seq, name='example.wav'):
    """Decode a sequence of codec token ids with ``wavtokenizer`` and save the audio to ``name``.

    Args:
        token_seq: non-empty sequence of integer token ids (e.g. the list returned by
            ``generate``).
        name: output path handed to ``torchaudio.save``; the file is written as
            16-bit signed PCM at 24000 Hz.

    Raises:
        ValueError: if ``token_seq`` is empty.
    """
    if len(token_seq) == 0:
        # An empty list would become a float32 (1, 0) tensor and fail obscurely downstream.
        raise ValueError("token_seq is empty: nothing to decode")

    # Decoding is inference only, so no autograd graph is needed.
    with torch.no_grad():
        audio_tokens = torch.tensor([token_seq], dtype=torch.long, device=device)
        features = wavtokenizer.codes_to_features(audio_tokens)
        bandwidth_id = torch.tensor([0], device=device)
        audio_out = wavtokenizer.decode(features, bandwidth_id=bandwidth_id)

    torchaudio.save(name, audio_out.cpu(), sample_rate=24000, encoding='PCM_S', bits_per_sample=16)

In [9]:
# Smoke test: sample one forced full-length sequence from the checkpoint_100 model, feeding
# only the last 40 tokens at each step, and decode it to the default file (example.wav).
ans = generate(model100, tmp=1, force=True, watch_tail=40)
tokens_to_wav(ans)

100%|██████████| 599/599 [00:02<00:00, 236.81it/s]


In [10]:
!mkdir /kaggle/working/WavTokenizer/wavs

In [11]:
# Sweep every (temperature, context-window) pair over the three checkpoints and write one
# clip per combination. File names carry the temperature's index, not its value.
for i, tmp in enumerate([0.01, 0.1, 1, 3, 5, 10]):
    for tail in [None, 10, 20, 30, 40, 50, 60]:
        sampling = dict(tmp=tmp, force=True, watch_tail=tail)
        # Generate for all three checkpoints first (in order), then save the clips.
        ans1, ans2, ans3 = [generate(m, **sampling) for m in (model100, model200, model300)]
        for idx, tokens in enumerate((ans1, ans2, ans3), start=1):
            tokens_to_wav(tokens, f'wavs/model{idx}_tmp{i}_tail{tail}.wav')

100%|██████████| 599/599 [00:11<00:00, 50.89it/s]
100%|██████████| 599/599 [00:11<00:00, 51.46it/s]
100%|██████████| 599/599 [00:11<00:00, 51.56it/s]
100%|██████████| 599/599 [00:00<00:00, 703.56it/s]
100%|██████████| 599/599 [00:00<00:00, 685.91it/s]
100%|██████████| 599/599 [00:00<00:00, 680.88it/s]
100%|██████████| 599/599 [00:01<00:00, 473.73it/s]
100%|██████████| 599/599 [00:01<00:00, 473.08it/s]
100%|██████████| 599/599 [00:01<00:00, 477.72it/s]
100%|██████████| 599/599 [00:01<00:00, 347.06it/s]
100%|██████████| 599/599 [00:01<00:00, 358.57it/s]
100%|██████████| 599/599 [00:01<00:00, 357.27it/s]
100%|██████████| 599/599 [00:02<00:00, 292.45it/s]
100%|██████████| 599/599 [00:02<00:00, 294.12it/s]
100%|██████████| 599/599 [00:02<00:00, 298.16it/s]
100%|██████████| 599/599 [00:02<00:00, 246.78it/s]
100%|██████████| 599/599 [00:02<00:00, 249.13it/s]
100%|██████████| 599/599 [00:02<00:00, 249.05it/s]
100%|██████████| 599/599 [00:02<00:00, 214.98it/s]
100%|██████████| 599/599 [00:02<00

In [12]:
import shutil
# Zip the directory of generated clips. The base name 'wavs' is relative, so wavs.zip is
# written to the current working directory; the call returns the archive's path.
shutil.make_archive('wavs', 'zip', '/kaggle/working/WavTokenizer/wavs')

'/kaggle/working/WavTokenizer/wavs.zip'