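# AudioLM

Text-to-music inference with MusicLM. It combines a pretrained AudioLM stack (HuBERT semantic tokens, SoundStream codec, semantic/coarse/fine transformers) with MuLaN text conditioning.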

### Libraries:

In [1]:
import torch
import torchaudio
from audiolm_pytorch import HubertWithKmeans
from audiolm_pytorch import SemanticTransformer
from audiolm_pytorch import CoarseTransformer
from audiolm_pytorch import FineTransformer
from audiolm_pytorch import AudioLMSoundStream, AudioLM
from musiclm_pytorch import MuLaNEmbedQuantizer
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer
from musiclm_pytorch import MusicLM

In [2]:
# Pretrained HuBERT model and its k-means quantizer (used as the semantic tokenizer).
HUBERT_DIR = './models/hubert/'
checkpoint_path = HUBERT_DIR + 'hubert_base_ls960.pt'
kmeans_path = HUBERT_DIR + 'hubert_base_ls960_L9_km500.bin'

In [3]:
# Run on the GPU when one is available. Each module is built once and moved to
# this device instead of duplicating every constructor per device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Semantic tokenizer: HuBERT checkpoint plus its k-means quantizer.
wav2vec = HubertWithKmeans(checkpoint_path=checkpoint_path, kmeans_path=kmeans_path)

# Acoustic codec. rq_num_quantizers=8 must equal
# num_coarse_quantizers + num_fine_quantizers (4 + 4) configured below.
# NOTE(review): no codec weights are loaded, so this SoundStream keeps its random
# initialization. TODO: load the codec the coarse/fine transformers were trained
# against; otherwise the decoded waveform will not be meaningful.
soundstream = AudioLMSoundStream(
    codebook_size=1024,
    strides=(2, 4, 5, 8),
    target_sample_hz=16000,
    rq_num_quantizers=8,
)

# Semantic-token transformer; its vocabulary is the HuBERT k-means codebook.
semantic_transformer = SemanticTransformer(
    num_semantic_tokens=wav2vec.codebook_size,
    dim=1024,
    depth=6,
    audio_text_condition=True,
)

coarse_transformer = CoarseTransformer(
    num_semantic_tokens=wav2vec.codebook_size,
    codebook_size=1024,
    num_coarse_quantizers=4,  # consistent with training
    dim=1024,
    depth=6,
    audio_text_condition=True,
)

fine_transformer = FineTransformer(
    num_coarse_quantizers=4,  # consistent with training
    num_fine_quantizers=4,
    codebook_size=1024,
    dim=1024,
    depth=6,
    audio_text_condition=True,
)

# Move each transformer to `device`, then load its trained weights.
# map_location remaps the stored tensors onto `device`. Without it, a checkpoint
# that holds CUDA tensors cannot be loaded on a CPU-only machine.
for stage, weights_path in (
    (semantic_transformer, 'semantic_transformer.pth'),
    (coarse_transformer, 'coarse_transformer.pth'),
    (fine_transformer, 'fine_transformer.pth'),
):
    stage.to(device)
    stage.load_state_dict(torch.load(weights_path, map_location=device))

audiolm = AudioLM(
    wav2vec=wav2vec,
    codec=soundstream,
    semantic_transformer=semantic_transformer,
    coarse_transformer=coarse_transformer,
    fine_transformer=fine_transformer,
)


# MuLaN

In [4]:
# Hyperparameters shared by MuLaN's audio and text towers.
_mulan_tower_kwargs = dict(dim=512, depth=6, heads=8, dim_head=64)

# Audio tower, plus its spectrogram settings.
audio_transformer = AudioSpectrogramTransformer(
    spec_n_fft=128,
    spec_win_length=24,
    spec_aug_stretch_factor=0.8,
    **_mulan_tower_kwargs,
)

# Text tower.
text_transformer = TextTransformer(**_mulan_tower_kwargs)

# NOTE(review): no MuLaN checkpoint is loaded, so both towers keep their random
# initialization. TODO: confirm this is intended. Conditioning from an untrained
# MuLaN presumably will not match what the AudioLM transformers were trained on.
mulan = MuLaN(audio_transformer=audio_transformer, text_transformer=text_transformer)

# One conditioning dim per namespace. Each is 1024, matching dim=1024 of the
# semantic, coarse and fine transformers.
quantizer = MuLaNEmbedQuantizer(
    mulan=mulan,
    conditioning_dims=(1024, 1024, 1024),
    namespaces=('semantic', 'coarse', 'fine'),
)


# MusicLM

In [5]:
# Move MusicLM (and the submodules it registers) to the GPU when available,
# otherwise keep it on the CPU. A single constructor replaces the duplicated
# per-device branches.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

musiclm = MusicLM(
    audio_lm=audiolm,
    mulan_embed_quantizer=quantizer,
).to(device)

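# Inference

Note: the recorded run below was interrupted (`KeyboardInterrupt`) during the fine stage. As a result `music` was never assigned, and the save cells after it have not been executed.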

In [6]:
# Text prompt that conditions generation.
# NOTE: this is slow. In the recorded run the coarse stage alone took ~39 min
# before the run was interrupted in the fine stage. num_samples presumably
# multiplies that cost (TODO: confirm).
prompt = 'the crystalline sounds of the piano in a ballroom'
music = musiclm(prompt, num_samples=4)

generating semantic:   0%|          | 2/2048 [00:00<09:50,  3.47it/s]
generating coarse: 100%|██████████| 512/512 [39:28<00:00,  4.63s/it]
generating fine:   0%|          | 0/512 [00:10<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Save a CPU copy of the generated waveform. A CUDA tensor would otherwise need
# map_location to be loaded on a machine without a GPU.
torch.save(music.detach().cpu(), 'generated_music.pt')

In [None]:
# Write the generated waveform to a wav file.
output_path = "out.wav"
# Must match the codec's target_sample_hz (16000 in the SoundStream config).
# The previous 44100 made the file play ~2.76x too fast with shifted pitch.
sample_rate = 16000

waveform = music.detach().cpu()
# torchaudio.save expects a 2-D (channels, frames) tensor, so a 1-D waveform
# is promoted to a single channel.
if waveform.ndim == 1:
    waveform = waveform.unsqueeze(0)
torchaudio.save(output_path, waveform, sample_rate)