In [1]:
!nvidia-smi

# If the command above fails, no GPU is available or detected, and the `.cuda()` calls in the cells below will not work

Mon Jul 24 15:23:24 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000001:00:00.0 Off |                    0 |
| N/A   32C    P0    42W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [1]:
#%pip install audiolm-pytorch
#%pip install pydub
#%pip install scikit-learn==0.24.0

## Setup

### Imports & paths

In [1]:
# imports
import math
import wave
import struct
import os
import urllib.request
import tarfile
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio


# Locations of the dataset, checkpoints and quantizer used throughout the notebook.
dataset_folder = "audio/wav"

# SoundStream checkpoint -- the file name can change depending on the number of steps.
soundstream_ckpt = "/".join(("results", "soundstream.8.pt"))

# HuBERT weights plus the quantizer that goes with them; the quantizer is the one
# listed in row "HuBERT Base (~95M params)", column Quantizer.
_hubert_dir = "hubert"
hubert_ckpt = _hubert_dir + "/hubert_base_ls960.pt"
hubert_quantizer = _hubert_dir + "/hubert_base_ls960_L9_km500.bin"

## Inference

In [13]:
# Build every component that AudioLM is assembled from.
#
# Only the HuBERT/k-means model and the SoundStream codec are restored from disk in
# this cell; the three transformers below are constructed without loading a checkpoint.
# NOTE(review): none of the transformers is created with `has_condition = True`;
# presumably a `text = ...` prompt passed to AudioLM is ignored in that case --
# confirm against the audiolm-pytorch README before relying on text conditioning.

# Values that must agree between the coarse and fine transformers are defined once
# so the two constructor calls cannot drift apart.
coarse_fine_codebook_size = 1024
num_coarse_quantizers = 3

# Semantic tokenizer: HuBERT features quantized with the k-means model.
wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
).cuda()

# Codec: load from the configured checkpoint instead of repeating the path literal,
# so changing `soundstream_ckpt` in the setup cell is enough to switch checkpoints.
soundstream = SoundStream.init_and_load_from(f'./{soundstream_ckpt}').cuda()

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = coarse_fine_codebook_size,
    num_coarse_quantizers = num_coarse_quantizers,
    dim = 512,
    depth = 6
).cuda()

fine_transformer = FineTransformer(
    num_coarse_quantizers = num_coarse_quantizers,
    num_fine_quantizers = 5,
    codebook_size = coarse_fine_codebook_size,
    dim = 512,
    depth = 6
).cuda()

In [14]:
# Chain the tokenizer, codec and the three transformers behind one AudioLM object.
audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

# Seed PyTorch's random number generators before sampling so that re-running this
# cell with the same model weights is repeatable as far as PyTorch's RNG is concerned.
generation_seed = 0
torch.manual_seed(generation_seed)

# Other ways to sample, kept for reference:
#   unconditional:      generated_wav = audiolm(batch_size = 1)
#   primed with audio:  generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8))
#
# Text-conditioned sampling.
# NOTE(review): the transformers passed in above were built without `has_condition`;
# presumably the prompt has no effect unless they are -- confirm.
generated_wav_with_text_condition = audiolm(text = ['Hola, mi nombre es Ingrid. Que tal?'])

generating semantic:   5%|▌         | 106/2048 [00:00<00:08, 222.17it/s]
generating coarse: 100%|██████████| 512/512 [00:10<00:00, 47.30it/s]
generating fine: 100%|██████████| 512/512 [01:42<00:00,  4.98it/s]


In [17]:
# Write the generated waveform to a WAV file.
output_path = "out.wav"

# Take the sample rate from the codec that decoded the waveform rather than
# hard-coding 44100: a rate in the file header that differs from the codec's rate
# would change the pitch and speed on playback.
# NOTE(review): assumes `soundstream.target_sample_hz` is the codec's output rate --
# confirm against the audiolm-pytorch SoundStream documentation.
sample_rate = int(soundstream.target_sample_hz)

# torchaudio writes from CPU tensors, so move the result off the GPU first.
torchaudio.save(output_path, generated_wav_with_text_condition.cpu(), sample_rate)