In [4]:
!nvidia-smi

# If this doesn't work, there's no GPU available or detected

Sat Feb 18 21:56:31 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| N/A   44C    P8    15W /  60W |      8MiB /  6144MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [1]:
!pip install audiolm-pytorch



## Setup

### Imports & paths

In [1]:
# imports
import math
import wave
import struct
import os
import urllib.request
import tarfile
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio


# define all dataset paths, checkpoints, etc
dataset_folder = "../../datasets/fma_medium/"
soundstream_ckpt = "runs/soundstream.8.pt" # this can change depending on number of steps
hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = f'hubert/hubert_base_ls960_L9_km500.bin' # listed in row "HuBERT Base (~95M params)", column Quantizer

  from .autonotebook import tqdm as notebook_tqdm
2023-02-18 23:04:25 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


## Training

Now that we have a dataset, we can train AudioLM.

**Note**: do NOT type "y" to overwrite previous experiments/ checkpoints when running through the cells here unless you're ready to the entire results folder! Otherwise you will end up erasing things (e.g. you train SoundStream first, and if you choose "overwrite" then you lose the SoundStream checkpoint when you then train SemanticTransformer).

### SoundStream

In [4]:
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 12,
    target_sample_hz = 24000,
    strides = (3, 4, 5, 8)
)

# soundstream_path = "runs/soundstream(0.051).67.pt"
# soundstream.load_from_trainer_saved_obj(f"./{soundstream_path}")  #Load pretrained

trainer = SoundStreamTrainer(
    soundstream,
    folder = dataset_folder,
    batch_size = 4,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length = 480 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 20000,
    random_split_seed=5670
).cuda()

trainer.train()

training with dataset of 23730 samples and validating with randomly splitted 1249 samples
0: soundstream total loss: 47903.944, soundstream recon loss: 0.097 | discr (scale 1) loss: 0.011 | discr (scale 0.5) loss: 0.005 | discr (scale 0.25) loss: 0.000
0: saving model to results
0: saving to results
1: soundstream total loss: 40633.139, soundstream recon loss: 0.070 | discr (scale 1) loss: 0.003 | discr (scale 0.5) loss: 0.012 | discr (scale 0.25) loss: 0.000
1: saving model to results
1: saving to results
2: soundstream total loss: 45552.765, soundstream recon loss: 0.101 | discr (scale 1) loss: 0.006 | discr (scale 0.5) loss: 0.000 | discr (scale 0.25) loss: 0.000
3: soundstream total loss: 49481.227, soundstream recon loss: 0.103 | discr (scale 1) loss: 0.017 | discr (scale 0.5) loss: 0.033 | discr (scale 0.25) loss: 0.008


### SoundStream Check

In [2]:
from audiolm_pytorch.data import SoundDataset, get_dataloader
from audiolm_pytorch.trainer import cycle

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)
soundstream_path = "runs/soundstream.2000.pt"
soundstream.load_from_trainer_saved_obj(f"./{soundstream_path}")  #Load ckpt for test

soundstream.eval()
soundstream = soundstream.to("cuda")

ds = SoundDataset(
            folder = dataset_folder,
            max_length = 24000*15,  #15 Sec
            target_sample_hz = soundstream.target_sample_hz,
            seq_len_multiple_of = soundstream.seq_len_multiple_of
)

dl = get_dataloader(ds, batch_size = 1, num_workers = 0, shuffle = True)

dl_iter = cycle(dl)

wave, = next(dl_iter)
wave = wave.to("cuda")

In [3]:
from pathlib import Path

results_folder = Path('./results_test')
results_folder.mkdir(exist_ok=True)

filename = str(results_folder / f'orig_{0}.wav')
torchaudio.save(filename, wave.cpu(), soundstream.target_sample_hz)

with torch.no_grad():
    recons = soundstream(wave, return_recons_only = True)

for ind, recon in enumerate(recons.unbind(dim = 0)):
    filename = str(results_folder / f'sample_{ind}.wav')
    torchaudio.save(filename, recon.cpu().detach(), soundstream.target_sample_hz)

### SemanticTransformer

In [18]:
# hubert checkpoints can be downloaded at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
if not os.path.isdir("hubert"):
  os.makedirs("hubert")
if not os.path.isfile(hubert_ckpt):
  hubert_ckpt_download = f"https://dl.fbaipublicfiles.com/{hubert_ckpt}"
  urllib.request.urlretrieve(hubert_ckpt_download, f"./{hubert_ckpt}")
if not os.path.isfile(hubert_quantizer):
  hubert_quantizer_download = f"https://dl.fbaipublicfiles.com/{hubert_quantizer}"
  urllib.request.urlretrieve(hubert_quantizer_download, f"./{hubert_quantizer}")

wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()


trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

trainer.train()

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


training with dataset of 2 samples and validating with randomly splitted 1 samples
do you want to clear previous experiment checkpoints and results? (y/n) n
0: loss: 6.648584365844727
0: valid loss 5.763116359710693
0: saving model to results
training complete


### CoarseTransformer

In [19]:
wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load(f"./{soundstream_ckpt}")

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
)

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    soundstream = soundstream,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


training with dataset of 2 samples and validating with randomly splitted 1 samples
do you want to clear previous experiment checkpoints and results? (y/n) n
0: loss: 63.983970642089844
0: valid loss 63.398582458496094
0: saving model to results
1: loss: 65.85967254638672
2: loss: 62.4722900390625
2: valid loss 50.01605987548828
3: loss: 11.735434532165527
4: loss: 3.976104497909546
4: valid loss 46.094608306884766
4: saving model to results
5: loss: 58.27140426635742
6: loss: 41.68347930908203
6: valid loss 45.54595184326172
7: loss: 2.2387890815734863
8: loss: 0.4718627631664276
8: valid loss 39.10848617553711
8: saving model to results
training complete


### FineTransformer

In [20]:
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load(f"./{soundstream_ckpt}")

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
)

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    soundstream = soundstream,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 9
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()

training with dataset of 2 samples and validating with randomly splitted 1 samples
do you want to clear previous experiment checkpoints and results? (y/n) n
0: loss: 70.90608215332031
0: valid loss 65.99951171875
0: saving model to results
1: loss: 43.6014289855957
2: loss: 8.300681114196777
3: loss: 61.23375701904297
4: loss: 63.34052276611328
5: loss: 2.010118246078491
6: loss: 56.52588653564453
7: loss: 0.5423888564109802
8: loss: 0.005095238331705332
training complete


## Inference

In [21]:
# Everything together
audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

generated_wav = audiolm(batch_size = 1)

generating semantic:   0%|          | 10/2048 [00:00<00:25, 78.55it/s]
generating coarse: 100%|██████████| 512/512 [00:14<00:00, 34.83it/s]
generating fine: 100%|██████████| 512/512 [02:56<00:00,  2.91it/s]


In [22]:
output_path = "out.wav"
sample_rate = 44100
torchaudio.save(output_path, generated_wav.cpu(), sample_rate)