In [None]:
# Fetch the project repository and switch the notebook's working directory into it.
# Later cells import `models.*`, presumably from this checkout (TODO confirm).
!git clone https://github.com/FedorZaitsev/VKR25
%cd VKR25

In [None]:
import os
# List of CUDA compute capabilities, exported before torchph is installed;
# presumably so its CUDA extensions get built for whichever GPU is attached (TODO confirm).
os.environ['TORCH_CUDA_ARCH_LIST']="5.0 5.2 5.3 6.0 6.1 6.2 7.0 7.2 7.5 8.0 8.6 8.7 8.9 9.0"

# Install Ninja, then clone torchph into the current directory and install it in editable mode.
!pip install Ninja
!git clone https://github.com/c-hofer/torchph.git
!pip install -e torchph

import sys
# Also put the cloned checkout on sys.path for the already-running kernel.
sys.path.append("/kaggle/working/VKR25/torchph")

In [None]:
import os

# Notebook-wide settings: RNG seed, special-token ids / vocabulary layout,
# and data-loading / batching parameters.
config = dict(
    SEED=228,

    BOS_TOKEN=4096,
    EOS_TOKEN=4097,
    INP_PAD_TOKEN=4098,
    TAR_PAD_TOKEN=-100,
    VOCAB_SIZE=4099,
    MAX_LENGTH=256,
    OVERLAP=64,

    NUM_WORKERS=4,
    BATCH_SIZE=16,

    ACCUM_STEPS=1,
)

# Mirror every setting into the process environment (values stringified),
# in the same order as they appear in `config`.
os.environ.update({name: str(setting) for name, setting in config.items()})

In [None]:
import torch
import random
import numpy as np

# Seed Python's, NumPy's and PyTorch's random number generators from the shared
# config so that token sampling is repeatable across runs.
random.seed(config['SEED'])
np.random.seed(config['SEED'])
torch.manual_seed(config['SEED'])

# Kaggle input directory (not referenced by the other cells visible here) and
# the device that models and tensors are placed on.
root_dir = '/kaggle/input/groove-tokens'
device = 'cuda'

In [None]:
from models.topotransformer_model import TopoTransformerModel, PositionalEncoding, CustomTransformerEncoderLayer
# Allow-list the model class for PyTorch's restricted (weights_only) unpickler.
# `torch.serialization.safe_globals` is a context manager and has no effect when
# merely called without `with`; `add_safe_globals` performs the persistent,
# process-wide registration that was intended here.
torch.serialization.add_safe_globals([TopoTransformerModel])

In [None]:
# Load the checkpoint onto `device`; the loaded object is used directly as the
# model by later cells (it is called, and `.eval()` / `.parameters()` are used).
# NOTE(review): weights_only=False performs a full unpickle, which can execute
# arbitrary code — only load checkpoints that come from a trusted source.
model = torch.load('/kaggle/input/topotransformer/pytorch/default/1/checkpoint_50.pt', map_location=device, weights_only=False)

In [None]:
from tqdm import tqdm

# Special-token ids; ids below BOS_TOKEN are treated as ordinary (audio) tokens.
BOS_TOKEN = 4096
EOS_TOKEN = 4097
INP_PAD_TOKEN = 4098

# Default prompt: just the BOS token.
sequence = [BOS_TOKEN]

def generate(model, seq=sequence, max_len=750, tmp=1.0, force=False, watch_tail=256):
    """Autoregressively sample tokens from ``model``, continuing the prompt ``seq``.

    At every step the last ``watch_tail`` tokens are fed to the model together
    with a causal mask from ``model.generate_square_subsequent_mask``, and the
    next token is drawn from the temperature-scaled logits of the last position.

    Args:
        model: Module exposing ``parameters()``, ``eval()``,
            ``generate_square_subsequent_mask(n)`` and ``model(src, mask)``
            returning logits indexed as ``[batch, position, vocab]``.
            It is switched to eval mode as a side effect.
        seq: Non-empty prompt of token ids. It is copied, never mutated.
        max_len: Target total length (prompt included); at most
            ``max_len - len(seq)`` tokens are sampled.
        tmp: Sampling temperature; must be positive.
        force: If False, sampling a special token (id >= the smallest of
            BOS/EOS/PAD) stops generation. If True, the step is re-sampled
            from the non-special tokens only, so generation always runs for
            the full number of steps.
        watch_tail: Number of trailing tokens shown to the model. ``None``
            means 256; values above 256 are capped to 256 (presumably the
            model's maximum context length — TODO confirm).

    Returns:
        list[int]: Prompt plus sampled tokens with the first element (expected
        to be BOS) removed.

    Raises:
        ValueError: If ``seq`` is empty, ``tmp`` is not positive, or
            ``watch_tail`` is smaller than 1.
    """
    max_context = 256

    if len(seq) == 0:
        raise ValueError("seq must contain at least one token (normally BOS_TOKEN)")
    if tmp <= 0:
        raise ValueError(f"tmp must be positive, got {tmp}")
    if watch_tail is None:
        watch_tail = max_context
    elif watch_tail < 1:
        # A window of 0 would turn the slice below into `[-0:]`, i.e. the
        # whole history, silently defeating the context limit.
        raise ValueError(f"watch_tail must be at least 1, got {watch_tail}")
    watch_tail = min(max_context, watch_tail)

    # Special tokens occupy the ids from here upwards.
    first_special = min(BOS_TOKEN, EOS_TOKEN, INP_PAD_TOKEN)

    device = next(model.parameters()).device
    model.eval()
    generated = list(seq)

    with torch.no_grad():
        for _ in tqdm(range(max_len - len(generated))):
            # Only the tail is ever used, so only the tail is turned into a tensor.
            window = generated[-watch_tail:]
            src = torch.tensor(window, dtype=torch.long, device=device).unsqueeze(0)

            mask = model.generate_square_subsequent_mask(src.shape[1]).to(device)
            logits = model(src, mask)[0, -1, :] / tmp
            token = torch.distributions.categorical.Categorical(logits=logits).sample().item()

            if token >= first_special:
                if not force:
                    break
                # Re-draw from the ordinary tokens only.
                token = torch.distributions.categorical.Categorical(
                    logits=logits[:first_special]
                ).sample().item()
            generated.append(token)

    return generated[1:]

In [None]:
# Clone the WavTokenizer code, switch the working directory into the checkout,
# and download a pretrained checkpoint (.ckpt) plus its YAML config into it.
!git clone https://github.com/jishengpeng/WavTokenizer
%cd WavTokenizer
!wget https://huggingface.co/novateur/WavTokenizer-medium-music-audio-75token/resolve/main/wavtokenizer_medium_music_audio_320_24k_v2.ckpt
!wget https://huggingface.co/novateur/WavTokenizer-medium-music-audio-75token/resolve/main/wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml

In [None]:
from encoder.utils import convert_audio
import torchaudio
import torch
from decoder.pretrained import WavTokenizer

# Absolute locations of the WavTokenizer checkpoint and its YAML config.
model_path = "/kaggle/working/VKR25/WavTokenizer/wavtokenizer_medium_music_audio_320_24k_v2.ckpt"
config_path = "/kaggle/working/VKR25/WavTokenizer/wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"

# Build the pretrained tokenizer and move it to the target device in one step.
wavtokenizer = WavTokenizer.from_pretrained0802(config_path, model_path).to(device)

In [None]:
def tokens_to_wav(token_seq, name='example.wav'):
    """Decode a token sequence to audio and write it to ``name``.

    The tokens are wrapped into a batch of one on the global ``device``,
    converted to features and decoded by the global ``wavtokenizer`` (with
    bandwidth id 0), and the result is saved as 16-bit signed PCM at 24 kHz.

    Args:
        token_seq: Sequence of token ids accepted by
            ``wavtokenizer.codes_to_features``.
        name: Output file path.
    """
    codes = torch.tensor([token_seq], device=device)
    waveform = wavtokenizer.decode(
        wavtokenizer.codes_to_features(codes),
        bandwidth_id=torch.tensor([0], device=device),
    )
    torchaudio.save(
        name,
        waveform.cpu(),
        sample_rate=24000,
        encoding='PCM_S',
        bits_per_sample=16,
    )

In [None]:
!mkdir /kaggle/working/VKR25/WavTokenizer/wavs

In [None]:
# Prompt token sequence used to condition generation; it starts with the BOS
# token (4096), followed by token ids below 4096.
seq=[4096,  653,  864, 2634, 2613, 2870, 1530, 1351, 3417,  824, 2616, 3417,
        3547,  590, 4093, 3780,  727, 3955, 3780,  311, 2284, 3461, 2511, 2616,
        1615,  270, 3417, 2616, 1492, 3381, 2870, 3497, 1122, 2662, 1122,  707,
        2597, 1615, 3199, 1959,    4,  767, 2628, 2583, 3123, 1877, 2419, 1387,
        2190, 2279,  190, 1050, 1304, 3479,  957,  374, 3706, 2688, 2159, 2481,
        2677, 3021, 2711, 1296,  933, 2801,  760,  417, 2711, 1395,  431, 3253,
        2076,  862, 3503, 2724, 3503,  155, 3327,  599, 3786, 3880, 2012, 3489,
        2017,  916, 3886, 3966, 2619,  635,  123,  699, 1491, 2638,  911,  465,
        1955, 3955,  824,  476,  263, 4093, 2425,  727, 2284, 1615, 3786, 2694,
        2105,  691, 4040, 3620, 3739, 2696, 2012, 2031,  132, 3330,  911,  487,
         694, 2184,  845, 1351,  911,  487, 1842,  590, 2169,  155, 3436,  917,
        1515,  114, 2933, 2811, 2695, 2712, 2233,   19, 3274, 4075,  431,  659,
         694, 1437, 1033, 3977,  862, 2287, 1619, 2499,  280, 3434, 2268, 1293,
         921, 3877, 3801, 2153,  477, 2638, 1810,  654, 2784, 2645,  477, 3427,
        2065, 4093, 4040, 1441,  155, 2496, 4089,  821, 3764, 1711, 2169, 3862,
        2486, 2158,  461, 1421, 3681, 1432, 3212,  674, 1160, 1955, 1503,  864,
        3612, 2240,  465, 3862, 1437,  155, 2588, 3461, 2699, 1351, 2070, 2042,
        1028, 3955, 2731, 2499,  590, 2616, 2287,  700, 2070, 2042,  311, 3845,
        3417, 3461, 2616, 2616,   81, 2902, 2042,  590,  590, 3845,  270, 2597,
        3207, 4035, 3896, 3788,  251,  348, 1552,  629, 1906,  629, 2734, 2811,
         496, 2957, 4095, 2423, 1529, 3207, 3909, 3528, 2771, 1880,  837, 1039,
        1988, 3969, 1955]

In [None]:
# Sample 100 continuations of the prompt `seq` and decode each one to its own WAV file.
# With force=True special tokens are re-sampled rather than ending generation,
# so every run continues up to generate()'s default max_len.
for i in range(100):
    ans1 = generate(model, seq=seq, tmp=1, force=True, watch_tail=250)
    # Relative output path: relies on the working directory containing `wavs/`.
    tokens_to_wav(ans1, f'wavs/sample{i}.wav')

In [None]:
import shutil
# Bundle the generated WAV files into `wavs.zip` in the current working directory.
shutil.make_archive(
    base_name='wavs',
    format='zip',
    root_dir='/kaggle/working/VKR25/WavTokenizer/wavs',
)