In [1]:
import shutil
from tqdm.auto import tqdm
from pathlib import Path
# imports
import math
import wave
import struct
import os
import urllib.request
import tarfile
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


## Using only clean records from librispeech dataset

In [2]:
libtrispeech_path = Path("../librispeech_finetuning")
assert libtrispeech_path.exists(), "Must be exists!"

marker = "clean"
destination_dir = Path(f"../librispeech_{marker}_flac")
if not destination_dir.exists():
    destination_dir.mkdir(parents=True)

In [3]:
for flac in tqdm(libtrispeech_path.glob("**/*.flac")):
    if marker in str(flac):
        shutil.copy(flac, Path(destination_dir, flac.name))

2763it [00:37, 74.66it/s]  


In [4]:
# convert flac to wavs (!)
# !pip install AudioConverter
# !audioconvert convert  vall-e/data/libri --output-format .wav

In [5]:
destination_wavs_dir = Path("../librispeech_wavs")
assert destination_wavs_dir.exists(), "Must be exists!"

In [17]:
# define all dataset paths, checkpoints, etc
dataset_folder = str(destination_wavs_dir)

# soundstream_ckpt = "results/soundstream.8.pt" # this can change depending on number of steps
# hubert_ckpt = 'hubert/hubert_base_ls960.pt'
# hubert_quantizer = f'hubert/hubert_base_ls960_L9_km500.bin' # listed in row "HuBERT Base (~95M params)", column Quantizer

# soundstream_ckpt = "results/soundstream.8.pt" # this can change depending on number of steps
soundstream_ckpt = "results/soundstream.0.pt" # this can change depending on number of steps

hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = f'hubert/hubert_base_ls960_L9_km500.bin' 
# listed in row "HuBERT Base (~95M params)", column Quantizer

In [7]:
# Placeholder data generation
def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=44100.0):
  # code adapted from https://stackoverflow.com/a/33913403
  audio = []
  num_samples = duration_ms * (sample_rate / 1000.0)
  for x in range(int(num_samples)):
    audio.append(volume * math.sin(2 * math.pi * freq * (x / sample_rate)))
  return audio

def save_wav(file_name, audio, sample_rate=44100.0):
  # Open up a wav file
  wav_file=wave.open(file_name,"w")
  # wav params
  nchannels = 1
  sampwidth = 2
  # 44100 is the industry standard sample rate - CD quality.  If you need to
  # save on file size you can adjust it downwards. The stanard for low quality
  # is 8000 or 8kHz.
  nframes = len(audio)
  comptype = "NONE"
  compname = "not compressed"
  wav_file.setparams((nchannels, sampwidth, sample_rate, nframes, comptype, compname))
  # WAV files here are using short, 16 bit, signed integers for the 
  # sample size.  So we multiply the floating point data we have by 32767, the
  # maximum value for a short integer.  NOTE: It is theortically possible to
  # use the floating point -1.0 to 1.0 data directly in a WAV file but not
  # obvious how to do that using the wave module in python.
  for sample in audio:
      wav_file.writeframes(struct.pack('h', int( sample * 32767.0 )))
  wav_file.close()
  return

def make_placeholder_dataset():
  # Make a placeholder dataset with a few .wav files that you can "train" on, just to verify things work e2e
  if os.path.isdir(dataset_folder):
    return
  os.makedirs(dataset_folder)
  save_wav(f"{dataset_folder}/example.wav", get_sinewave())
  save_wav(f"{dataset_folder}/example2.wav", get_sinewave(duration_ms=500))
  os.makedirs(f"{dataset_folder}/subdirectory")
  save_wav(f"{dataset_folder}/subdirectory/example.wav", get_sinewave(freq=330.0))

make_placeholder_dataset()

In [8]:
# Get actual dataset. Uncomment this if you want to try training on real data

# full dataset: https://www.openslr.org/12
# We'll use https://us.openslr.org/resources/12/dev-clean.tar.gz development set, "clean" speech.
# We *should* train on, well, training, but this is just to demo running things end-to-end at all so I just picked a small clean set.

# url = "https://us.openslr.org/resources/12/dev-clean.tar.gz"
# filename = "dev-clean"
# filename_targz = filename + ".tar.gz"
# if not os.path.isfile(filename_targz):
#   urllib.request.urlretrieve(url, filename_targz)
# if not os.path.isdir(filename):
#   # open file
#   with tarfile.open(filename_targz) as t:
#     t.extractall(filename)

## Training

Now that we have a dataset, we can train AudioLM.

**Note**: do NOT type "y" to overwrite previous experiments/ checkpoints when running through the cells here unless you're ready to the entire results folder! Otherwise you will end up erasing things (e.g. you train SoundStream first, and if you choose "overwrite" then you lose the SoundStream checkpoint when you then train SemanticTransformer).

In [9]:
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = dataset_folder,
    batch_size = 2,  # effective batch size of 32
    grad_accum_every = 8,         
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
).cuda()

# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()

training with dataset of 1295 samples and validating with randomly splitted 69 samples


do you want to clear previous experiment checkpoints and results? (y/n)  y


0: soundstream total loss: 29.295, soundstream recon loss: 0.019 | discr (scale 1) loss: 2.000 | discr (scale 0.5) loss: 2.000 | discr (scale 0.25) loss: 2.000
0: saving to results
0: saving model to results
1: soundstream total loss: 23.763, soundstream recon loss: 0.008 | discr (scale 1) loss: 2.001 | discr (scale 0.5) loss: 2.001 | discr (scale 0.25) loss: 2.000
2: soundstream total loss: 22.886, soundstream recon loss: 0.004 | discr (scale 1) loss: 1.999 | discr (scale 0.5) loss: 2.000 | discr (scale 0.25) loss: 2.000
2: saving to results
3: soundstream total loss: 34.897, soundstream recon loss: 0.011 | discr (scale 1) loss: 1.998 | discr (scale 0.5) loss: 1.999 | discr (scale 0.25) loss: 1.999
4: soundstream total loss: 23.379, soundstream recon loss: 0.006 | discr (scale 1) loss: 1.998 | discr (scale 0.5) loss: 1.999 | discr (scale 0.25) loss: 1.999
4: saving to results
4: saving model to results
5: soundstream total loss: 22.289, soundstream recon loss: 0.004 | discr (scale 1) 

In [10]:
# hubert checkpoints can be downloaded at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
if not os.path.isdir("hubert"):
  os.makedirs("hubert")
if not os.path.isfile(hubert_ckpt):
  hubert_ckpt_download = f"https://dl.fbaipublicfiles.com/{hubert_ckpt}"
  urllib.request.urlretrieve(hubert_ckpt_download, f"./{hubert_ckpt}")
if not os.path.isfile(hubert_quantizer):
  hubert_quantizer_download = f"https://dl.fbaipublicfiles.com/{hubert_quantizer}"
  urllib.request.urlretrieve(hubert_quantizer_download, f"./{hubert_quantizer}")

wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()


trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

trainer.train()

Downloading (…)lve/main/config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 605/605 [00:00<00:00, 3.67MB/s]


training with dataset of 1295 samples and validating with randomly splitted 69 samples


do you want to clear previous experiment checkpoints and results? (y/n)  y


Computing label assignment and total inertia
0: loss: 6.317939281463623
Computing label assignment and total inertia
0: valid loss 6.484222888946533
0: saving model to results
training complete


In [23]:
wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

# semantic.transformer.0.pt
soundstream_ckpt = Path("./results/semantic.transformer.0.pt")
assert soundstream_ckpt.exists()
soundstream.load(str(soundstream_ckpt))

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
)

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    codec = soundstream,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()

RuntimeError: Error(s) in loading state_dict for SoundStream:
	Missing key(s) in state_dict: "encoder.0.conv.weight", "encoder.0.conv.bias", "encoder.1.0.fn.0.conv.weight", "encoder.1.0.fn.0.conv.bias", "encoder.1.0.fn.2.conv.weight", "encoder.1.0.fn.2.conv.bias", "encoder.1.1.fn.0.conv.weight", "encoder.1.1.fn.0.conv.bias", "encoder.1.1.fn.2.conv.weight", "encoder.1.1.fn.2.conv.bias", "encoder.1.2.fn.0.conv.weight", "encoder.1.2.fn.0.conv.bias", "encoder.1.2.fn.2.conv.weight", "encoder.1.2.fn.2.conv.bias", "encoder.1.3.conv.weight", "encoder.1.3.conv.bias", "encoder.2.0.fn.0.conv.weight", "encoder.2.0.fn.0.conv.bias", "encoder.2.0.fn.2.conv.weight", "encoder.2.0.fn.2.conv.bias", "encoder.2.1.fn.0.conv.weight", "encoder.2.1.fn.0.conv.bias", "encoder.2.1.fn.2.conv.weight", "encoder.2.1.fn.2.conv.bias", "encoder.2.2.fn.0.conv.weight", "encoder.2.2.fn.0.conv.bias", "encoder.2.2.fn.2.conv.weight", "encoder.2.2.fn.2.conv.bias", "encoder.2.3.conv.weight", "encoder.2.3.conv.bias", "encoder.3.0.fn.0.conv.weight", "encoder.3.0.fn.0.conv.bias", "encoder.3.0.fn.2.conv.weight", "encoder.3.0.fn.2.conv.bias", "encoder.3.1.fn.0.conv.weight", "encoder.3.1.fn.0.conv.bias", "encoder.3.1.fn.2.conv.weight", "encoder.3.1.fn.2.conv.bias", "encoder.3.2.fn.0.conv.weight", "encoder.3.2.fn.0.conv.bias", "encoder.3.2.fn.2.conv.weight", "encoder.3.2.fn.2.conv.bias", "encoder.3.3.conv.weight", "encoder.3.3.conv.bias", "encoder.4.0.fn.0.conv.weight", "encoder.4.0.fn.0.conv.bias", "encoder.4.0.fn.2.conv.weight", "encoder.4.0.fn.2.conv.bias", "encoder.4.1.fn.0.conv.weight", "encoder.4.1.fn.0.conv.bias", "encoder.4.1.fn.2.conv.weight", "encoder.4.1.fn.2.conv.bias", "encoder.4.2.fn.0.conv.weight", "encoder.4.2.fn.0.conv.bias", "encoder.4.2.fn.2.conv.weight", "encoder.4.2.fn.2.conv.bias", "encoder.4.3.conv.weight", "encoder.4.3.conv.bias", "encoder.5.conv.weight", "encoder.5.conv.bias", "encoder_attn.layers.0.0.q_scale", "encoder_attn.layers.0.0.k_scale", "encoder_attn.layers.0.0.norm.weight", "encoder_attn.layers.0.0.norm.bias", "encoder_attn.layers.0.0.to_qkv.weight", "encoder_attn.layers.0.0.attn_fn.rel_pos.inv_freq", "encoder_attn.layers.0.0.to_out.weight", "encoder_attn.layers.0.1.0.weight", "encoder_attn.layers.0.1.0.bias", "encoder_attn.layers.0.1.1.weight", "encoder_attn.layers.0.1.4.weight", "encoder_film.to_cond.weight", "encoder_film.to_cond.bias", "rq.rvqs.0.layers.0._codebook.initted", "rq.rvqs.0.layers.0._codebook.cluster_size", "rq.rvqs.0.layers.0._codebook.embed_avg", "rq.rvqs.0.layers.0._codebook.embed", "rq.rvqs.0.layers.1._codebook.initted", "rq.rvqs.0.layers.1._codebook.cluster_size", "rq.rvqs.0.layers.1._codebook.embed_avg", "rq.rvqs.0.layers.1._codebook.embed", "rq.rvqs.0.layers.2._codebook.initted", "rq.rvqs.0.layers.2._codebook.cluster_size", "rq.rvqs.0.layers.2._codebook.embed_avg", "rq.rvqs.0.layers.2._codebook.embed", "rq.rvqs.0.layers.3._codebook.initted", "rq.rvqs.0.layers.3._codebook.cluster_size", "rq.rvqs.0.layers.3._codebook.embed_avg", "rq.rvqs.0.layers.3._codebook.embed", "rq.rvqs.0.layers.4._codebook.initted", "rq.rvqs.0.layers.4._codebook.cluster_size", "rq.rvqs.0.layers.4._codebook.embed_avg", "rq.rvqs.0.layers.4._codebook.embed", "rq.rvqs.0.layers.5._codebook.initted", "rq.rvqs.0.layers.5._codebook.cluster_size", "rq.rvqs.0.layers.5._codebook.embed_avg", "rq.rvqs.0.layers.5._codebook.embed", "rq.rvqs.0.layers.6._codebook.initted", "rq.rvqs.0.layers.6._codebook.cluster_size", "rq.rvqs.0.layers.6._codebook.embed_avg", "rq.rvqs.0.layers.6._codebook.embed", "rq.rvqs.0.layers.7._codebook.initted", "rq.rvqs.0.layers.7._codebook.cluster_size", "rq.rvqs.0.layers.7._codebook.embed_avg", "rq.rvqs.0.layers.7._codebook.embed", "decoder_film.to_cond.weight", "decoder_film.to_cond.bias", "decoder_attn.layers.0.0.q_scale", "decoder_attn.layers.0.0.k_scale", "decoder_attn.layers.0.0.norm.weight", "decoder_attn.layers.0.0.norm.bias", "decoder_attn.layers.0.0.to_qkv.weight", "decoder_attn.layers.0.0.attn_fn.rel_pos.inv_freq", "decoder_attn.layers.0.0.to_out.weight", "decoder_attn.layers.0.1.0.weight", "decoder_attn.layers.0.1.0.bias", "decoder_attn.layers.0.1.1.weight", "decoder_attn.layers.0.1.4.weight", "decoder.0.conv.weight", "decoder.0.conv.bias", "decoder.1.0.conv.weight", "decoder.1.0.conv.bias", "decoder.1.1.fn.0.conv.weight", "decoder.1.1.fn.0.conv.bias", "decoder.1.1.fn.2.conv.weight", "decoder.1.1.fn.2.conv.bias", "decoder.1.2.fn.0.conv.weight", "decoder.1.2.fn.0.conv.bias", "decoder.1.2.fn.2.conv.weight", "decoder.1.2.fn.2.conv.bias", "decoder.1.3.fn.0.conv.weight", "decoder.1.3.fn.0.conv.bias", "decoder.1.3.fn.2.conv.weight", "decoder.1.3.fn.2.conv.bias", "decoder.2.0.conv.weight", "decoder.2.0.conv.bias", "decoder.2.1.fn.0.conv.weight", "decoder.2.1.fn.0.conv.bias", "decoder.2.1.fn.2.conv.weight", "decoder.2.1.fn.2.conv.bias", "decoder.2.2.fn.0.conv.weight", "decoder.2.2.fn.0.conv.bias", "decoder.2.2.fn.2.conv.weight", "decoder.2.2.fn.2.conv.bias", "decoder.2.3.fn.0.conv.weight", "decoder.2.3.fn.0.conv.bias", "decoder.2.3.fn.2.conv.weight", "decoder.2.3.fn.2.conv.bias", "decoder.3.0.conv.weight", "decoder.3.0.conv.bias", "decoder.3.1.fn.0.conv.weight", "decoder.3.1.fn.0.conv.bias", "decoder.3.1.fn.2.conv.weight", "decoder.3.1.fn.2.conv.bias", "decoder.3.2.fn.0.conv.weight", "decoder.3.2.fn.0.conv.bias", "decoder.3.2.fn.2.conv.weight", "decoder.3.2.fn.2.conv.bias", "decoder.3.3.fn.0.conv.weight", "decoder.3.3.fn.0.conv.bias", "decoder.3.3.fn.2.conv.weight", "decoder.3.3.fn.2.conv.bias", "decoder.4.0.conv.weight", "decoder.4.0.conv.bias", "decoder.4.1.fn.0.conv.weight", "decoder.4.1.fn.0.conv.bias", "decoder.4.1.fn.2.conv.weight", "decoder.4.1.fn.2.conv.bias", "decoder.4.2.fn.0.conv.weight", "decoder.4.2.fn.0.conv.bias", "decoder.4.2.fn.2.conv.weight", "decoder.4.2.fn.2.conv.bias", "decoder.4.3.fn.0.conv.weight", "decoder.4.3.fn.0.conv.bias", "decoder.4.3.fn.2.conv.weight", "decoder.4.3.fn.2.conv.bias", "decoder.5.conv.weight", "decoder.5.conv.bias", "discriminators.0.init_conv.weight", "discriminators.0.init_conv.bias", "discriminators.0.conv_layers.0.0.weight", "discriminators.0.conv_layers.0.0.bias", "discriminators.0.conv_layers.1.0.weight", "discriminators.0.conv_layers.1.0.bias", "discriminators.0.conv_layers.2.0.weight", "discriminators.0.conv_layers.2.0.bias", "discriminators.0.conv_layers.3.0.weight", "discriminators.0.conv_layers.3.0.bias", "discriminators.0.final_conv.0.weight", "discriminators.0.final_conv.0.bias", "discriminators.0.final_conv.2.weight", "discriminators.0.final_conv.2.bias", "discriminators.1.init_conv.weight", "discriminators.1.init_conv.bias", "discriminators.1.conv_layers.0.0.weight", "discriminators.1.conv_layers.0.0.bias", "discriminators.1.conv_layers.1.0.weight", "discriminators.1.conv_layers.1.0.bias", "discriminators.1.conv_layers.2.0.weight", "discriminators.1.conv_layers.2.0.bias", "discriminators.1.conv_layers.3.0.weight", "discriminators.1.conv_layers.3.0.bias", "discriminators.1.final_conv.0.weight", "discriminators.1.final_conv.0.bias", "discriminators.1.final_conv.2.weight", "discriminators.1.final_conv.2.bias", "discriminators.2.init_conv.weight", "discriminators.2.init_conv.bias", "discriminators.2.conv_layers.0.0.weight", "discriminators.2.conv_layers.0.0.bias", "discriminators.2.conv_layers.1.0.weight", "discriminators.2.conv_layers.1.0.bias", "discriminators.2.conv_layers.2.0.weight", "discriminators.2.conv_layers.2.0.bias", "discriminators.2.conv_layers.3.0.weight", "discriminators.2.conv_layers.3.0.bias", "discriminators.2.final_conv.0.weight", "discriminators.2.final_conv.0.bias", "discriminators.2.final_conv.2.weight", "discriminators.2.final_conv.2.bias", "stft_discriminator.init_conv.weight", "stft_discriminator.init_conv.bias", "stft_discriminator.layers.0.0.fn.0.weight", "stft_discriminator.layers.0.0.fn.0.bias", "stft_discriminator.layers.0.0.fn.1.b", "stft_discriminator.layers.0.0.fn.2.weight", "stft_discriminator.layers.0.0.fn.2.bias", "stft_discriminator.layers.0.1.weight", "stft_discriminator.layers.0.1.bias", "stft_discriminator.layers.1.0.fn.0.weight", "stft_discriminator.layers.1.0.fn.0.bias", "stft_discriminator.layers.1.0.fn.1.b", "stft_discriminator.layers.1.0.fn.2.weight", "stft_discriminator.layers.1.0.fn.2.bias", "stft_discriminator.layers.1.1.weight", "stft_discriminator.layers.1.1.bias", "stft_discriminator.layers.2.0.fn.0.weight", "stft_discriminator.layers.2.0.fn.0.bias", "stft_discriminator.layers.2.0.fn.1.b", "stft_discriminator.layers.2.0.fn.2.weight", "stft_discriminator.layers.2.0.fn.2.bias", "stft_discriminator.layers.2.1.weight", "stft_discriminator.layers.2.1.bias", "stft_discriminator.layers.3.0.fn.0.weight", "stft_discriminator.layers.3.0.fn.0.bias", "stft_discriminator.layers.3.0.fn.1.b", "stft_discriminator.layers.3.0.fn.2.weight", "stft_discriminator.layers.3.0.fn.2.bias", "stft_discriminator.layers.3.1.weight", "stft_discriminator.layers.3.1.bias", "stft_discriminator.layers.4.0.fn.0.weight", "stft_discriminator.layers.4.0.fn.0.bias", "stft_discriminator.layers.4.0.fn.1.b", "stft_discriminator.layers.4.0.fn.2.weight", "stft_discriminator.layers.4.0.fn.2.bias", "stft_discriminator.layers.4.1.weight", "stft_discriminator.layers.4.1.bias", "stft_discriminator.layers.5.0.fn.0.weight", "stft_discriminator.layers.5.0.fn.0.bias", "stft_discriminator.layers.5.0.fn.1.b", "stft_discriminator.layers.5.0.fn.2.weight", "stft_discriminator.layers.5.0.fn.2.bias", "stft_discriminator.layers.5.1.weight", "stft_discriminator.layers.5.1.bias", "stft_discriminator.final_conv.weight", "stft_discriminator.final_conv.bias", "mel_spec_transforms.0.spectrogram.window", "mel_spec_transforms.0.mel_scale.fb", "mel_spec_transforms.1.spectrogram.window", "mel_spec_transforms.1.mel_scale.fb", "mel_spec_transforms.2.spectrogram.window", "mel_spec_transforms.2.mel_scale.fb", "mel_spec_transforms.3.spectrogram.window", "mel_spec_transforms.3.mel_scale.fb", "mel_spec_transforms.4.spectrogram.window", "mel_spec_transforms.4.mel_scale.fb", "mel_spec_transforms.5.spectrogram.window", "mel_spec_transforms.5.mel_scale.fb". 
	Unexpected key(s) in state_dict: "start_token", "semantic_embedding.weight", "proj_text_embed.weight", "transformer.layers.0.0.norm.gamma", "transformer.layers.0.0.norm.beta", "transformer.layers.0.0.to_q.weight", "transformer.layers.0.0.to_kv.weight", "transformer.layers.0.0.to_out.0.weight", "transformer.layers.0.2.0.gamma", "transformer.layers.0.2.0.beta", "transformer.layers.0.2.1.weight", "transformer.layers.0.2.3.gamma", "transformer.layers.0.2.3.beta", "transformer.layers.0.2.5.weight", "transformer.layers.1.0.norm.gamma", "transformer.layers.1.0.norm.beta", "transformer.layers.1.0.to_q.weight", "transformer.layers.1.0.to_kv.weight", "transformer.layers.1.0.to_out.0.weight", "transformer.layers.1.2.0.gamma", "transformer.layers.1.2.0.beta", "transformer.layers.1.2.1.weight", "transformer.layers.1.2.3.gamma", "transformer.layers.1.2.3.beta", "transformer.layers.1.2.5.weight", "transformer.layers.2.0.norm.gamma", "transformer.layers.2.0.norm.beta", "transformer.layers.2.0.to_q.weight", "transformer.layers.2.0.to_kv.weight", "transformer.layers.2.0.to_out.0.weight", "transformer.layers.2.2.0.gamma", "transformer.layers.2.2.0.beta", "transformer.layers.2.2.1.weight", "transformer.layers.2.2.3.gamma", "transformer.layers.2.2.3.beta", "transformer.layers.2.2.5.weight", "transformer.layers.3.0.norm.gamma", "transformer.layers.3.0.norm.beta", "transformer.layers.3.0.to_q.weight", "transformer.layers.3.0.to_kv.weight", "transformer.layers.3.0.to_out.0.weight", "transformer.layers.3.2.0.gamma", "transformer.layers.3.2.0.beta", "transformer.layers.3.2.1.weight", "transformer.layers.3.2.3.gamma", "transformer.layers.3.2.3.beta", "transformer.layers.3.2.5.weight", "transformer.layers.4.0.norm.gamma", "transformer.layers.4.0.norm.beta", "transformer.layers.4.0.to_q.weight", "transformer.layers.4.0.to_kv.weight", "transformer.layers.4.0.to_out.0.weight", "transformer.layers.4.2.0.gamma", "transformer.layers.4.2.0.beta", "transformer.layers.4.2.1.weight", "transformer.layers.4.2.3.gamma", "transformer.layers.4.2.3.beta", "transformer.layers.4.2.5.weight", "transformer.layers.5.0.norm.gamma", "transformer.layers.5.0.norm.beta", "transformer.layers.5.0.to_q.weight", "transformer.layers.5.0.to_kv.weight", "transformer.layers.5.0.to_out.0.weight", "transformer.layers.5.2.0.gamma", "transformer.layers.5.2.0.beta", "transformer.layers.5.2.1.weight", "transformer.layers.5.2.3.gamma", "transformer.layers.5.2.3.beta", "transformer.layers.5.2.5.weight", "transformer.rel_pos_bias.net.0.0.weight", "transformer.rel_pos_bias.net.0.0.bias", "transformer.rel_pos_bias.net.1.0.weight", "transformer.rel_pos_bias.net.1.0.bias", "transformer.rel_pos_bias.net.2.0.weight", "transformer.rel_pos_bias.net.2.0.bias", "transformer.rel_pos_bias.net.3.weight", "transformer.rel_pos_bias.net.3.bias", "transformer.norm.gamma", "transformer.norm.beta", "to_logits.weight", "to_logits.bias". 

In [24]:
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load(f"./{soundstream_ckpt}")

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
)

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    codec = soundstream,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 9
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()

RuntimeError: Error(s) in loading state_dict for SoundStream:
	Missing key(s) in state_dict: "encoder.0.conv.weight", "encoder.0.conv.bias", "encoder.1.0.fn.0.conv.weight", "encoder.1.0.fn.0.conv.bias", "encoder.1.0.fn.2.conv.weight", "encoder.1.0.fn.2.conv.bias", "encoder.1.1.fn.0.conv.weight", "encoder.1.1.fn.0.conv.bias", "encoder.1.1.fn.2.conv.weight", "encoder.1.1.fn.2.conv.bias", "encoder.1.2.fn.0.conv.weight", "encoder.1.2.fn.0.conv.bias", "encoder.1.2.fn.2.conv.weight", "encoder.1.2.fn.2.conv.bias", "encoder.1.3.conv.weight", "encoder.1.3.conv.bias", "encoder.2.0.fn.0.conv.weight", "encoder.2.0.fn.0.conv.bias", "encoder.2.0.fn.2.conv.weight", "encoder.2.0.fn.2.conv.bias", "encoder.2.1.fn.0.conv.weight", "encoder.2.1.fn.0.conv.bias", "encoder.2.1.fn.2.conv.weight", "encoder.2.1.fn.2.conv.bias", "encoder.2.2.fn.0.conv.weight", "encoder.2.2.fn.0.conv.bias", "encoder.2.2.fn.2.conv.weight", "encoder.2.2.fn.2.conv.bias", "encoder.2.3.conv.weight", "encoder.2.3.conv.bias", "encoder.3.0.fn.0.conv.weight", "encoder.3.0.fn.0.conv.bias", "encoder.3.0.fn.2.conv.weight", "encoder.3.0.fn.2.conv.bias", "encoder.3.1.fn.0.conv.weight", "encoder.3.1.fn.0.conv.bias", "encoder.3.1.fn.2.conv.weight", "encoder.3.1.fn.2.conv.bias", "encoder.3.2.fn.0.conv.weight", "encoder.3.2.fn.0.conv.bias", "encoder.3.2.fn.2.conv.weight", "encoder.3.2.fn.2.conv.bias", "encoder.3.3.conv.weight", "encoder.3.3.conv.bias", "encoder.4.0.fn.0.conv.weight", "encoder.4.0.fn.0.conv.bias", "encoder.4.0.fn.2.conv.weight", "encoder.4.0.fn.2.conv.bias", "encoder.4.1.fn.0.conv.weight", "encoder.4.1.fn.0.conv.bias", "encoder.4.1.fn.2.conv.weight", "encoder.4.1.fn.2.conv.bias", "encoder.4.2.fn.0.conv.weight", "encoder.4.2.fn.0.conv.bias", "encoder.4.2.fn.2.conv.weight", "encoder.4.2.fn.2.conv.bias", "encoder.4.3.conv.weight", "encoder.4.3.conv.bias", "encoder.5.conv.weight", "encoder.5.conv.bias", "encoder_attn.layers.0.0.q_scale", "encoder_attn.layers.0.0.k_scale", "encoder_attn.layers.0.0.norm.weight", "encoder_attn.layers.0.0.norm.bias", "encoder_attn.layers.0.0.to_qkv.weight", "encoder_attn.layers.0.0.attn_fn.rel_pos.inv_freq", "encoder_attn.layers.0.0.to_out.weight", "encoder_attn.layers.0.1.0.weight", "encoder_attn.layers.0.1.0.bias", "encoder_attn.layers.0.1.1.weight", "encoder_attn.layers.0.1.4.weight", "encoder_film.to_cond.weight", "encoder_film.to_cond.bias", "rq.rvqs.0.layers.0._codebook.initted", "rq.rvqs.0.layers.0._codebook.cluster_size", "rq.rvqs.0.layers.0._codebook.embed_avg", "rq.rvqs.0.layers.0._codebook.embed", "rq.rvqs.0.layers.1._codebook.initted", "rq.rvqs.0.layers.1._codebook.cluster_size", "rq.rvqs.0.layers.1._codebook.embed_avg", "rq.rvqs.0.layers.1._codebook.embed", "rq.rvqs.0.layers.2._codebook.initted", "rq.rvqs.0.layers.2._codebook.cluster_size", "rq.rvqs.0.layers.2._codebook.embed_avg", "rq.rvqs.0.layers.2._codebook.embed", "rq.rvqs.0.layers.3._codebook.initted", "rq.rvqs.0.layers.3._codebook.cluster_size", "rq.rvqs.0.layers.3._codebook.embed_avg", "rq.rvqs.0.layers.3._codebook.embed", "rq.rvqs.0.layers.4._codebook.initted", "rq.rvqs.0.layers.4._codebook.cluster_size", "rq.rvqs.0.layers.4._codebook.embed_avg", "rq.rvqs.0.layers.4._codebook.embed", "rq.rvqs.0.layers.5._codebook.initted", "rq.rvqs.0.layers.5._codebook.cluster_size", "rq.rvqs.0.layers.5._codebook.embed_avg", "rq.rvqs.0.layers.5._codebook.embed", "rq.rvqs.0.layers.6._codebook.initted", "rq.rvqs.0.layers.6._codebook.cluster_size", "rq.rvqs.0.layers.6._codebook.embed_avg", "rq.rvqs.0.layers.6._codebook.embed", "rq.rvqs.0.layers.7._codebook.initted", "rq.rvqs.0.layers.7._codebook.cluster_size", "rq.rvqs.0.layers.7._codebook.embed_avg", "rq.rvqs.0.layers.7._codebook.embed", "decoder_film.to_cond.weight", "decoder_film.to_cond.bias", "decoder_attn.layers.0.0.q_scale", "decoder_attn.layers.0.0.k_scale", "decoder_attn.layers.0.0.norm.weight", "decoder_attn.layers.0.0.norm.bias", "decoder_attn.layers.0.0.to_qkv.weight", "decoder_attn.layers.0.0.attn_fn.rel_pos.inv_freq", "decoder_attn.layers.0.0.to_out.weight", "decoder_attn.layers.0.1.0.weight", "decoder_attn.layers.0.1.0.bias", "decoder_attn.layers.0.1.1.weight", "decoder_attn.layers.0.1.4.weight", "decoder.0.conv.weight", "decoder.0.conv.bias", "decoder.1.0.conv.weight", "decoder.1.0.conv.bias", "decoder.1.1.fn.0.conv.weight", "decoder.1.1.fn.0.conv.bias", "decoder.1.1.fn.2.conv.weight", "decoder.1.1.fn.2.conv.bias", "decoder.1.2.fn.0.conv.weight", "decoder.1.2.fn.0.conv.bias", "decoder.1.2.fn.2.conv.weight", "decoder.1.2.fn.2.conv.bias", "decoder.1.3.fn.0.conv.weight", "decoder.1.3.fn.0.conv.bias", "decoder.1.3.fn.2.conv.weight", "decoder.1.3.fn.2.conv.bias", "decoder.2.0.conv.weight", "decoder.2.0.conv.bias", "decoder.2.1.fn.0.conv.weight", "decoder.2.1.fn.0.conv.bias", "decoder.2.1.fn.2.conv.weight", "decoder.2.1.fn.2.conv.bias", "decoder.2.2.fn.0.conv.weight", "decoder.2.2.fn.0.conv.bias", "decoder.2.2.fn.2.conv.weight", "decoder.2.2.fn.2.conv.bias", "decoder.2.3.fn.0.conv.weight", "decoder.2.3.fn.0.conv.bias", "decoder.2.3.fn.2.conv.weight", "decoder.2.3.fn.2.conv.bias", "decoder.3.0.conv.weight", "decoder.3.0.conv.bias", "decoder.3.1.fn.0.conv.weight", "decoder.3.1.fn.0.conv.bias", "decoder.3.1.fn.2.conv.weight", "decoder.3.1.fn.2.conv.bias", "decoder.3.2.fn.0.conv.weight", "decoder.3.2.fn.0.conv.bias", "decoder.3.2.fn.2.conv.weight", "decoder.3.2.fn.2.conv.bias", "decoder.3.3.fn.0.conv.weight", "decoder.3.3.fn.0.conv.bias", "decoder.3.3.fn.2.conv.weight", "decoder.3.3.fn.2.conv.bias", "decoder.4.0.conv.weight", "decoder.4.0.conv.bias", "decoder.4.1.fn.0.conv.weight", "decoder.4.1.fn.0.conv.bias", "decoder.4.1.fn.2.conv.weight", "decoder.4.1.fn.2.conv.bias", "decoder.4.2.fn.0.conv.weight", "decoder.4.2.fn.0.conv.bias", "decoder.4.2.fn.2.conv.weight", "decoder.4.2.fn.2.conv.bias", "decoder.4.3.fn.0.conv.weight", "decoder.4.3.fn.0.conv.bias", "decoder.4.3.fn.2.conv.weight", "decoder.4.3.fn.2.conv.bias", "decoder.5.conv.weight", "decoder.5.conv.bias", "discriminators.0.init_conv.weight", "discriminators.0.init_conv.bias", "discriminators.0.conv_layers.0.0.weight", "discriminators.0.conv_layers.0.0.bias", "discriminators.0.conv_layers.1.0.weight", "discriminators.0.conv_layers.1.0.bias", "discriminators.0.conv_layers.2.0.weight", "discriminators.0.conv_layers.2.0.bias", "discriminators.0.conv_layers.3.0.weight", "discriminators.0.conv_layers.3.0.bias", "discriminators.0.final_conv.0.weight", "discriminators.0.final_conv.0.bias", "discriminators.0.final_conv.2.weight", "discriminators.0.final_conv.2.bias", "discriminators.1.init_conv.weight", "discriminators.1.init_conv.bias", "discriminators.1.conv_layers.0.0.weight", "discriminators.1.conv_layers.0.0.bias", "discriminators.1.conv_layers.1.0.weight", "discriminators.1.conv_layers.1.0.bias", "discriminators.1.conv_layers.2.0.weight", "discriminators.1.conv_layers.2.0.bias", "discriminators.1.conv_layers.3.0.weight", "discriminators.1.conv_layers.3.0.bias", "discriminators.1.final_conv.0.weight", "discriminators.1.final_conv.0.bias", "discriminators.1.final_conv.2.weight", "discriminators.1.final_conv.2.bias", "discriminators.2.init_conv.weight", "discriminators.2.init_conv.bias", "discriminators.2.conv_layers.0.0.weight", "discriminators.2.conv_layers.0.0.bias", "discriminators.2.conv_layers.1.0.weight", "discriminators.2.conv_layers.1.0.bias", "discriminators.2.conv_layers.2.0.weight", "discriminators.2.conv_layers.2.0.bias", "discriminators.2.conv_layers.3.0.weight", "discriminators.2.conv_layers.3.0.bias", "discriminators.2.final_conv.0.weight", "discriminators.2.final_conv.0.bias", "discriminators.2.final_conv.2.weight", "discriminators.2.final_conv.2.bias", "stft_discriminator.init_conv.weight", "stft_discriminator.init_conv.bias", "stft_discriminator.layers.0.0.fn.0.weight", "stft_discriminator.layers.0.0.fn.0.bias", "stft_discriminator.layers.0.0.fn.1.b", "stft_discriminator.layers.0.0.fn.2.weight", "stft_discriminator.layers.0.0.fn.2.bias", "stft_discriminator.layers.0.1.weight", "stft_discriminator.layers.0.1.bias", "stft_discriminator.layers.1.0.fn.0.weight", "stft_discriminator.layers.1.0.fn.0.bias", "stft_discriminator.layers.1.0.fn.1.b", "stft_discriminator.layers.1.0.fn.2.weight", "stft_discriminator.layers.1.0.fn.2.bias", "stft_discriminator.layers.1.1.weight", "stft_discriminator.layers.1.1.bias", "stft_discriminator.layers.2.0.fn.0.weight", "stft_discriminator.layers.2.0.fn.0.bias", "stft_discriminator.layers.2.0.fn.1.b", "stft_discriminator.layers.2.0.fn.2.weight", "stft_discriminator.layers.2.0.fn.2.bias", "stft_discriminator.layers.2.1.weight", "stft_discriminator.layers.2.1.bias", "stft_discriminator.layers.3.0.fn.0.weight", "stft_discriminator.layers.3.0.fn.0.bias", "stft_discriminator.layers.3.0.fn.1.b", "stft_discriminator.layers.3.0.fn.2.weight", "stft_discriminator.layers.3.0.fn.2.bias", "stft_discriminator.layers.3.1.weight", "stft_discriminator.layers.3.1.bias", "stft_discriminator.layers.4.0.fn.0.weight", "stft_discriminator.layers.4.0.fn.0.bias", "stft_discriminator.layers.4.0.fn.1.b", "stft_discriminator.layers.4.0.fn.2.weight", "stft_discriminator.layers.4.0.fn.2.bias", "stft_discriminator.layers.4.1.weight", "stft_discriminator.layers.4.1.bias", "stft_discriminator.layers.5.0.fn.0.weight", "stft_discriminator.layers.5.0.fn.0.bias", "stft_discriminator.layers.5.0.fn.1.b", "stft_discriminator.layers.5.0.fn.2.weight", "stft_discriminator.layers.5.0.fn.2.bias", "stft_discriminator.layers.5.1.weight", "stft_discriminator.layers.5.1.bias", "stft_discriminator.final_conv.weight", "stft_discriminator.final_conv.bias", "mel_spec_transforms.0.spectrogram.window", "mel_spec_transforms.0.mel_scale.fb", "mel_spec_transforms.1.spectrogram.window", "mel_spec_transforms.1.mel_scale.fb", "mel_spec_transforms.2.spectrogram.window", "mel_spec_transforms.2.mel_scale.fb", "mel_spec_transforms.3.spectrogram.window", "mel_spec_transforms.3.mel_scale.fb", "mel_spec_transforms.4.spectrogram.window", "mel_spec_transforms.4.mel_scale.fb", "mel_spec_transforms.5.spectrogram.window", "mel_spec_transforms.5.mel_scale.fb". 
	Unexpected key(s) in state_dict: "start_token", "semantic_embedding.weight", "proj_text_embed.weight", "transformer.layers.0.0.norm.gamma", "transformer.layers.0.0.norm.beta", "transformer.layers.0.0.to_q.weight", "transformer.layers.0.0.to_kv.weight", "transformer.layers.0.0.to_out.0.weight", "transformer.layers.0.2.0.gamma", "transformer.layers.0.2.0.beta", "transformer.layers.0.2.1.weight", "transformer.layers.0.2.3.gamma", "transformer.layers.0.2.3.beta", "transformer.layers.0.2.5.weight", "transformer.layers.1.0.norm.gamma", "transformer.layers.1.0.norm.beta", "transformer.layers.1.0.to_q.weight", "transformer.layers.1.0.to_kv.weight", "transformer.layers.1.0.to_out.0.weight", "transformer.layers.1.2.0.gamma", "transformer.layers.1.2.0.beta", "transformer.layers.1.2.1.weight", "transformer.layers.1.2.3.gamma", "transformer.layers.1.2.3.beta", "transformer.layers.1.2.5.weight", "transformer.layers.2.0.norm.gamma", "transformer.layers.2.0.norm.beta", "transformer.layers.2.0.to_q.weight", "transformer.layers.2.0.to_kv.weight", "transformer.layers.2.0.to_out.0.weight", "transformer.layers.2.2.0.gamma", "transformer.layers.2.2.0.beta", "transformer.layers.2.2.1.weight", "transformer.layers.2.2.3.gamma", "transformer.layers.2.2.3.beta", "transformer.layers.2.2.5.weight", "transformer.layers.3.0.norm.gamma", "transformer.layers.3.0.norm.beta", "transformer.layers.3.0.to_q.weight", "transformer.layers.3.0.to_kv.weight", "transformer.layers.3.0.to_out.0.weight", "transformer.layers.3.2.0.gamma", "transformer.layers.3.2.0.beta", "transformer.layers.3.2.1.weight", "transformer.layers.3.2.3.gamma", "transformer.layers.3.2.3.beta", "transformer.layers.3.2.5.weight", "transformer.layers.4.0.norm.gamma", "transformer.layers.4.0.norm.beta", "transformer.layers.4.0.to_q.weight", "transformer.layers.4.0.to_kv.weight", "transformer.layers.4.0.to_out.0.weight", "transformer.layers.4.2.0.gamma", "transformer.layers.4.2.0.beta", "transformer.layers.4.2.1.weight", "transformer.layers.4.2.3.gamma", "transformer.layers.4.2.3.beta", "transformer.layers.4.2.5.weight", "transformer.layers.5.0.norm.gamma", "transformer.layers.5.0.norm.beta", "transformer.layers.5.0.to_q.weight", "transformer.layers.5.0.to_kv.weight", "transformer.layers.5.0.to_out.0.weight", "transformer.layers.5.2.0.gamma", "transformer.layers.5.2.0.beta", "transformer.layers.5.2.1.weight", "transformer.layers.5.2.3.gamma", "transformer.layers.5.2.3.beta", "transformer.layers.5.2.5.weight", "transformer.rel_pos_bias.net.0.0.weight", "transformer.rel_pos_bias.net.0.0.bias", "transformer.rel_pos_bias.net.1.0.weight", "transformer.rel_pos_bias.net.1.0.bias", "transformer.rel_pos_bias.net.2.0.weight", "transformer.rel_pos_bias.net.2.0.bias", "transformer.rel_pos_bias.net.3.weight", "transformer.rel_pos_bias.net.3.bias", "transformer.norm.gamma", "transformer.norm.beta", "to_logits.weight", "to_logits.bias". 

In [21]:
coarse_transformer = None
fine_transformer = None

# Everything together
audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

generated_wav = audiolm(batch_size = 1)

BeartypeCallHintParamViolation: Method audiolm_pytorch.audiolm_pytorch.AudioLM.__init__() parameter coarse_transformer="None" violates type hint <class 'audiolm_pytorch.audiolm_pytorch.CoarseTransformer'>, as <class "builtins.NoneType"> "None" not instance of <class "audiolm_pytorch.audiolm_pytorch.CoarseTransformer">.

In [None]:
output_path = "out.wav"
sample_rate = 44100
torchaudio.save(output_path, generated_wav.cpu(), sample_rate)