In [1]:
import torch
from audiolm_pytorch import EncodecWrapper

  from .autonotebook import tqdm as notebook_tqdm
2024-01-08 15:21:34 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


In [2]:
# Dummy mono waveform used to smoke-test the codecs: a batch of 1 with
# 512 * 320 = 163,840 samples.

# Fall back to the CPU when CUDA is unavailable so this cell does not crash on
# machines without a GPU; with a GPU present the tensor lands on CUDA as before.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# A dedicated seeded generator makes the waveform identical on every
# Restart-and-Run-All without touching the global RNG state.
rng = torch.Generator().manual_seed(0)

audio = torch.randn(1, 512 * 320, generator = rng).to(device)
audio.shape

torch.Size([1, 163840])

In [3]:
from audiolm_pytorch import AudioLMSoundStream, MusicLMSoundStream

# Hyperparameters for the SoundStream codec, gathered in one place before construction.
soundstream_kwargs = dict(
    codebook_size = 4096,
    rq_num_quantizers = 8,
    # multi-headed (grouped) residual vector quantization, as proposed in https://arxiv.org/abs/2305.02765
    rq_groups = 2,
    # residual lookup-free quantization - an unpublished technique with reports of successful usage
    use_lookup_free_quantizer = True,
    # residual finite scalar quantization is left switched off
    use_finite_scalar_quantizer = False,
    # local attention receptive field at the bottleneck
    attn_window_size = 128,
    # two local-attention transformer blocks (an addition over the original soundstream; encodec uses lstms instead)
    attn_depth = 2,
)

soundstream = AudioLMSoundStream(**soundstream_kwargs).cuda()

# Tokenize the dummy waveform and display the shape of the resulting code tensor.
codes = soundstream.tokenize(audio)
codes.shape

torch.Size([2, 1, 512, 8])

In [4]:
# Encodec codec moved to the GPU and switched to eval mode in a single expression.
encodec = EncodecWrapper().cuda().eval()

# NOTE: `codes` holds the complete return value of the call, not a single tensor;
# entry 1 is the tensor whose shape is displayed (torch.Size([1, 512, 8]) in the
# recorded run for the 163,840-sample input).
codes = encodec(audio)
codes[1].shape

torch.Size([1, 512, 8])

In [5]:
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

# HuBERT checkpoint and k-means files; both can be downloaded from
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
hubert_checkpoint_path = '../NASFolder/pretrained/hubert/hubert_base_ls960.pt'
hubert_kmeans_path = '../NASFolder/pretrained/hubert/hubert_base_ls960_L9_km500.bin'

# Wrapper supplying the semantic tokens (and `codebook_size`) used by the transformer stages.
wav2vec = HubertWithKmeans(checkpoint_path = hubert_checkpoint_path, kmeans_path = hubert_kmeans_path)


In [6]:
# Semantic-stage transformer; its token vocabulary size is read from the
# HubertWithKmeans wrapper so the two always agree.
semantic_transformer_kwargs = dict(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    flash_attn = True,
)

semantic_transformer = SemanticTransformer(**semantic_transformer_kwargs).cuda()

Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda


In [7]:
# Trainer for the semantic stage, configured for a single training step
# (num_train_steps = 1) on the IEMOCAP folder.
semantic_trainer_kwargs = dict(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder ='../NASFolder/data/IEMOCAP',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1,
)

trainer = SemanticTransformerTrainer(**semantic_trainer_kwargs)

# Training is intentionally left disabled; uncomment to run it.
# trainer.train()

training with dataset of 9680 samples and validating with randomly splitted 510 samples


In [None]:
from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerTrainer

# Reuse the `wav2vec` (HubertWithKmeans) instance loaded earlier in the notebook
# rather than re-creating it from the placeholder './hubert/...' paths: that
# would load the model a second time and replace the working instance with one
# pointing at files that are not where this notebook keeps its checkpoints.

# Coarse-stage transformer. It is conditioned on the semantic tokens, so its
# semantic vocabulary size is taken from `wav2vec`.
coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,   # 3 coarse + 5 fine (fine stage) = 8, the last dim of the Encodec code tensor shown earlier - presumably its quantizer count, confirm
    dim = 512,
    depth = 6,
    flash_attn = True
)

# Trainer for the coarse stage, using the Encodec codec and the same dataset
# folder the semantic trainer was pointed at (instead of the '/path/to/audio/files' placeholder).
trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    codec = encodec,
    wav2vec = wav2vec,
    folder = '../NASFolder/data/IEMOCAP',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1_000_000
)

# Training is intentionally left disabled; uncomment to run it.
# trainer.train()

In [None]:
import torch
from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer

# Fine-stage transformer. The coarse quantizer count and codebook size mirror
# the values given to the coarse-stage transformer in this notebook.
fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,     # 3 + 5 = 8, the last dim of the Encodec code tensor shown earlier - presumably its quantizer count, confirm
    codebook_size = 1024,
    dim = 512,
    depth = 6,
    flash_attn = True
)

# Trainer for the fine stage, using the Encodec codec and the same dataset
# folder the semantic trainer was pointed at (instead of the '/path/to/audio/files' placeholder).
trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    codec = encodec,
    folder = '../NASFolder/data/IEMOCAP',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1_000_000
)

# Training is intentionally left disabled; uncomment to run it.
# trainer.train()

In [None]:
from audiolm_pytorch import AudioLM

# Assemble the full pipeline from the three transformer stages.
# The codec must be the one the coarse and fine stages were set up with
# (`encodec`, codebook_size = 1024, 3 + 5 quantizers). The SoundStream built
# earlier in this notebook is untrained and uses a different configuration
# (codebook_size = 4096, grouped RVQ), so it cannot decode these tokens.
audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = encodec,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

# unconditional generation
generated_wav = audiolm(batch_size = 1)

# or with priming from an existing waveform (here: random noise of 320 * 8 samples)

generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8))

# or with text condition, if given
# NOTE(review): none of the transformers in this notebook are constructed with
# text conditioning enabled, so this call presumably does not condition on the
# text - confirm before relying on it.

generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echos of bells'])