In [None]:
!pip install torchaudio
!pip install audiocraft
!pip install torch
!pip install wandb
!pip install ray

In [1]:
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.modules.conditioners import ClassifierFreeGuidanceDropout
from transformers import get_scheduler
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
import audiocraft
import ray
from ray import train, tune
import wandb
import random
import os

# Set up Ray

In [2]:
NUM_WORKERS = 16
BATCH_SIZE_PER_WORKER = 4
MODEL_NAME = 'facebook/musicgen-small'
ray.shutdown()
ray.init(
    num_cpus = 2,
    num_gpus = 1,
)

2023-11-26 13:40:02,741	INFO worker.py:1673 -- Started a local Ray instance.


0,1
Python version:,3.11.6
Ray version:,2.8.0


In [3]:
class AudioDataset(Dataset):
    def __init__(self, data_dir, no_label=False):
        self.data_dir = data_dir
        self.data_map = []

        dir_map = os.listdir(data_dir)
        for d in dir_map:
            name, ext = os.path.splitext(d)
            if ext == ".wav" or ext == ".mp3":
                if no_label:
                    self.data_map.append({"audio": os.path.join(data_dir, d)})
                    continue
                if os.path.exists(os.path.join(data_dir, name + ".txt")):
                    self.data_map.append(
                        {
                            "audio": os.path.join(data_dir, d),
                            "label": os.path.join(data_dir, name + ".txt"),
                        }
                    )
                else:
                    raise ValueError(f"No label file for {name}")

    def __len__(self):
        return len(self.data_map)

    def __getitem__(self, idx):
        data = self.data_map[idx]
        audio = data["audio"]
        label = data.get("label", "")

        return audio, label


def count_nans(tensor):
    nan_mask = torch.isnan(tensor)
    num_nans = torch.sum(nan_mask).item()
    return num_nans


def preprocess_audio(audio_path, model: MusicGen, duration: int = 10):
    """
    Preprocesses the audio file located at `audio_path` using the specified `model`.

    Args:
        audio_path (str): The path to the audio file.
        model (MusicGen): The music generation model.
        duration (int, optional): The desired duration of the preprocessed audio in seconds. Defaults to 30.

    Returns:
        torch.Tensor: The preprocessed audio codes.

    Raises:
        AssertionError: If the shape of the preprocessed audio is not as expected.
    """
    # Load the audio file
    wav, sr = torchaudio.load(audio_path)

    # Resample the audio to match the sample rate of the model
    wav = torchaudio.functional.resample(wav, sr, model.sample_rate)

    # Take the mean of the audio to convert stereo to mono
    wav = wav.mean(dim=0, keepdim=True)

    # Randomly select a segment of the audio with the desired duration
    end_sample = int(model.sample_rate * duration)
    start_sample = random.randrange(0, max(wav.shape[1] - end_sample, 1))
    wav = wav[:, start_sample: start_sample + end_sample]

    # Ensure that the audio has a single channel
    assert wav.shape[0] == 1

    # Move the audio to the GPU
    wav = wav.cuda()

    # Add a channel dimension to the audio
    wav = wav.unsqueeze(1)

    # Encode the audio using the compression model
    with torch.no_grad():
        gen_audio = model.compression_model.encode(wav)

    # Separate the audio codes and scale
    codes, scale = gen_audio

    # Ensure that the scale is None
    assert scale is None

    return codes


def fixnan(tensor: torch.Tensor):
    nan_mask = torch.isnan(tensor)
    result = torch.where(nan_mask, torch.zeros_like(tensor), tensor)

    return result


def one_hot_encode(tensor, num_classes=2048):
    """
    One-hot encodes a tensor.

    Args:
        tensor (torch.Tensor): The input tensor to be one-hot encoded.
        num_classes (int): The number of classes for one-hot encoding.

    Returns:
        torch.Tensor: The one-hot encoded tensor.
    """
    shape = tensor.shape
    one_hot = torch.zeros((shape[0], shape[1], num_classes))

    for i in range(shape[0]):
        for j in range(shape[1]):
            index = tensor[i, j].item()
            one_hot[i, j, index] = 1

    return one_hot

In [10]:
dataset_path = 'audio/'
model_id = 'facebook/musicgen-large'
lr = 1e-5
epochs = 10
use_wandb = 0
no_label = False
tune_text = False
save_step = None
grad_acc = 8
use_scaler = True
weight_decay = 1e-5
warmup_steps = 10
batch_size = 16
use_cfg = False


# Set the device to CUDA if available, otherwise to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if use_wandb:
  run = wandb.init(project="audiocraft")

# Load the pretrained model based on the given model_id
model = MusicGen.get_pretrained(model_id)
model.lm.load_state_dict(torch.load('models/lm_final.pt'))
for i, param in enumerate(model.lm.parameters()):
  if i >= 250:
    break
  param.requires_grad = False
# Convert the model's lm attribute to torch.float32
model.lm = model.lm.to(torch.float32)

# Create an instance of the AudioDataset using the dataset_path and no_label flag
dataset = AudioDataset(dataset_path, no_label=no_label)

# Create a DataLoader for the training dataset with the specified batch size and shuffle option
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Set the learning rate to lr
learning_rate = lr

# Set the model's lm attribute to training mode
model.lm.train()

# Create an instance of GradScaler for mixed precision training
scaler = torch.cuda.amp.GradScaler()

# Print the tuning mode based on the tune_text flag
if tune_text:
  print("Tuning text")
else:
  print("Tuning everything")

# Create an instance of AdamW optimizer with the appropriate parameters based on the tune_text flag
optimizer = AdamW(
  model.lm.condition_provider.parameters() if tune_text else model.lm.parameters(),
  lr=learning_rate,
  betas=(0.9, 0.95),
  weight_decay=weight_decay,
)

# Create a scheduler using the cosine annealing method
scheduler = get_scheduler(
  "cosine",
  optimizer,
  warmup_steps,
  int(epochs * len(train_dataloader) / grad_acc),
)

# Create an instance of CrossEntropyLoss as the criterion for training
criterion = nn.CrossEntropyLoss()

# Set the device to CUDA if available, otherwise to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the number of epochs to train
num_epochs = epochs

# Set the save_step and save_models flags
save_step = save_step
save_models = False if save_step is None else True

# Set the save_path for saving the trained models
save_path = "models/"

# Create the save_path directory if it doesn't exist
os.makedirs(save_path, exist_ok=True)

# Initialize the current_step counter
current_step = 0



Tuning everything


In [None]:
for epoch in range(num_epochs):
    torch.save(model.lm.state_dict(), f"{save_path}/lm_final.pt")
    for batch_idx, (audio, label) in enumerate(train_dataloader):
      try:
              optimizer.zero_grad()

              all_codes = []
              texts = []

              # where audio and label are just paths
              for inner_audio, l in zip(audio, label):
                  inner_audio = preprocess_audio(
                      inner_audio, model)  # returns tensor
                  if inner_audio is None:
                      continue
                  if use_cfg:
                      codes = torch.cat([inner_audio, inner_audio], dim=0)
                  else:
                      codes = inner_audio

                  all_codes.append(codes)
                  texts.append(open(l, "r").read().strip())

              attributes, _ = model._prepare_tokens_and_attributes(
                  texts, None)
              conditions = attributes
              if use_cfg:
                  null_conditions = ClassifierFreeGuidanceDropout(
                      p=1.0)(conditions)
                  conditions = conditions + null_conditions
              tokenized = model.lm.condition_provider.tokenize(conditions)
              cfg_conditions = model.lm.condition_provider(tokenized)
              condition_tensors = cfg_conditions

              if len(all_codes) == 0:
                  continue

              codes = torch.cat(all_codes, dim=0)

              with torch.autocast(device_type="cuda", dtype=torch.float16):
                  lm_output = model.lm.compute_predictions(
                      codes=codes, conditions=[], condition_tensors=condition_tensors
                  )

                  logits = lm_output.logits
                  mask = lm_output.mask
                  reshaped_tensor = codes.view(-1, 500)

                  one_hot_encode = torch.nn.functional.one_hot(
                      reshaped_tensor, num_classes=2048)
                  codes_preprocess = one_hot_encode.view(
                      batch_size, logits.shape[1], logits.shape[2], logits.shape[3])

                  codes_preprocess = codes_preprocess.cuda()
                  logits = logits.cuda()
                  mask = mask.cuda()

                  mask = mask.reshape(-1)
                  masked_logits = logits.reshape(-1, 2048)[mask]
                  masked_codes = codes_preprocess.reshape(
                      -1, 2048)[mask].to(torch.float16)

                  loss = criterion(masked_logits, masked_codes)

              current_step += 1 / grad_acc

              # assert count_nans(masked_logits) == 0

              (scaler.scale(loss) if use_scaler else loss).backward()

              total_norm = 0
              for p in model.lm.condition_provider.parameters():
                  try:
                      param_norm = p.grad.data.norm(2)
                      total_norm += param_norm.item() ** 2
                  except AttributeError:
                      pass
              total_norm = total_norm ** (1.0 / 2)

              if use_wandb:
                  run.log(
                      {
                          "loss": loss.item(),
                          "total_norm": total_norm,
                      }
                  )

              print(
                  f"Epoch: {epoch}/{num_epochs}, Batch: {batch_idx}/{len(train_dataloader)}, Loss: {loss.item()}"
              )

              del inner_audio, attributes, conditions, condition_tensors, codes, logits, mask, lm_output, one_hot_encode, reshaped_tensor, codes_preprocess, masked_logits, masked_codes

              if batch_idx % grad_acc != grad_acc - 1:
                  continue

              if use_scaler:
                  scaler.unscale_(optimizer)
              torch.nn.utils.clip_grad_norm_(model.lm.parameters(), 0.5)

              if use_scaler:
                  scaler.step(optimizer)
                  scaler.update()
              else:
                  optimizer.step()
              scheduler.step()

              if save_models:
                  if (
                      current_step == int(current_step)
                      and int(current_step) % save_step == 0
                  ):
                      torch.save(
                          model.lm.state_dict(
                          ), f"{save_path}/lm_{current_step}.pt"
                      )

      except KeyboardInterrupt as e:
        print(e)
        break
      except RuntimeError as e:
        print(e)
        continue
      except:
        continue


  if x.storage().data_ptr() != tensors[0].storage().data_ptr():


Epoch: 0/10, Batch: 0/625, Loss: 3.3584885597229004
Epoch: 0/10, Batch: 1/625, Loss: 3.7058818340301514
Epoch: 0/10, Batch: 2/625, Loss: 3.584254264831543
Epoch: 0/10, Batch: 3/625, Loss: 3.553462266921997
Epoch: 0/10, Batch: 4/625, Loss: 3.4933247566223145
Epoch: 0/10, Batch: 5/625, Loss: 3.408195734024048
Epoch: 0/10, Batch: 6/625, Loss: 3.426941394805908
Epoch: 0/10, Batch: 7/625, Loss: 3.885221481323242
Epoch: 0/10, Batch: 8/625, Loss: 3.775754451751709
Epoch: 0/10, Batch: 9/625, Loss: 3.5410821437835693
Epoch: 0/10, Batch: 10/625, Loss: 3.7751624584198
Epoch: 0/10, Batch: 11/625, Loss: 3.5388429164886475
Epoch: 0/10, Batch: 12/625, Loss: 3.3925137519836426
Epoch: 0/10, Batch: 13/625, Loss: 3.7414729595184326
Epoch: 0/10, Batch: 14/625, Loss: 3.5923209190368652
Epoch: 0/10, Batch: 15/625, Loss: 3.7700905799865723
Epoch: 0/10, Batch: 16/625, Loss: 3.459801435470581
Epoch: 0/10, Batch: 17/625, Loss: 3.6061666011810303
Epoch: 0/10, Batch: 18/625, Loss: 3.5898215770721436
Epoch: 0/10, 

In [None]:
from IPython.display import Audio
Audio('/content/audio/1699168569.6232119.mp3', autoplay=True, rate=16000)

In [None]:
results = tune.Tuner(
    tune.with_resources(trainable_thing, resources={'cpu':8, 'gpu':2}),
    tune_config = tune.TuneConfig(num_samples=1)
)
results.fit()

In [None]:
 logits = lm_output.logits
mask = lm_output.mask
reshaped_tensor = codes.view(-1, 500)

one_hot_encode = torch.nn.functional.one_hot(
reshaped_tensor, num_classes=2048)
codes_preprocess = one_hot_encode.view(
batch_size, logits.shape[1], logits.shape[2], logits.shape[3])

codes_preprocess = codes_preprocess.cuda()
logits = logits.cuda()
mask = mask.cuda()

mask = mask.reshape(-1)
masked_logits = logits.reshape(-1, 2048)[mask]
masked_codes = codes_preprocess.reshape(
-1, 2048)[mask].to(torch.float16)

In [None]:
results = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(train_muze),
        resources={'cpu': 2, 'gpu': 1}
    ),
    param_space={}
)

In [None]:
results.fit()

In [4]:
!wget https://storage.googleapis.com/muze-data/test/public.json

--2023-11-23 12:26:04--  https://storage.googleapis.com/muze-data/test/public.json
Resolving storage.googleapis.com (storage.googleapis.com)... 192.178.48.251, 192.178.49.27, 192.178.49.219, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|192.178.48.251|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 258706 (253K) [application/json]
Saving to: ‘public.json’


2023-11-23 12:26:04 (5.93 MB/s) - ‘public.json’ saved [258706/258706]



# Generating

In [None]:
description = 'The recording features a traditional song that consists of a wooden percussion, claps, shimmering shakers, groovy bass, sustained strings melody and synth keys melody. It sounds happy, fun and joyful'

In [7]:
from IPython.display import Audio
Audio('generated/1699168495.915989.wav', autoplay=True)

In [4]:
import torchaudio
from tqdm import trange
from audiocraft.models import MusicGen
import json
import torch

model = MusicGen.get_pretrained("facebook/musicgen-large")
model.lm.load_state_dict(torch.load('models/lm_final.pt'))
self = model
save_path = 'generated/'
with open("public.json", "r") as f:
    data = json.load(f)
description_list = []
file_name_list = []
for file_name, description in data.items():
    description_list.append(description)
    file_name_list.append(file_name)
    if len(description_list) < 50:
        continue
    attributes, prompt_tokens = self._prepare_tokens_and_attributes(description_list, None)
    print("attributes:", attributes)
    print("prompt_tokens:", prompt_tokens)
    
    duration = 10
    self.generation_params = {
        'max_gen_len': int(duration * self.frame_rate),
        'use_sampling': 1,
        'temp': 1.0,
        'top_k': 250,
        'top_p': 0.0,
        'cfg_coef': 3.0,
        'two_step_cfg': 0,
    }
    total = []

    with self.autocast:
        gen_tokens = self.lm.generate(
            prompt_tokens, attributes, callback=None, **self.generation_params)
        total.append(gen_tokens[..., prompt_tokens.shape[-1]
                                if prompt_tokens is not None else 0:])
        prompt_tokens = gen_tokens[..., -gen_tokens.shape[-1] // 2:]
    gen_tokens = torch.cat(total, -1)
    
    with torch.no_grad():
        self.compression_model.sample_rate = 16000
        gen_audio = self.compression_model.decode(gen_tokens, None)
        gen_audio = gen_audio.cpu()
        for i in range(len(description_list)):
            print(save_path + (file_name_list[i]))
            torchaudio.save(save_path + (file_name_list[i]), gen_audio[i,:, :160000], self.sample_rate)

    description_list = []
    file_name_list = []
    break





attributes: [ConditioningAttributes(text={'description': 'The recording features a widely spread electric piano melody, followed by synth pad chords. It sounds emotional and passionate.'}, wav={'self_wav': WavCondition(wav=tensor([[[0.]]], device='cuda:0'), length=tensor([0], device='cuda:0'), sample_rate=[32000], path=[None], seek_time=[])}, joint_embed={}), ConditioningAttributes(text={'description': 'The recording features a cover of a rock song and it consists of an arpeggiated acoustic guitar melody. It sounds groovy and the recording is noisy and in mono.'}, wav={'self_wav': WavCondition(wav=tensor([[[0.]]], device='cuda:0'), length=tensor([0], device='cuda:0'), sample_rate=[32000], path=[None], seek_time=[])}, joint_embed={}), ConditioningAttributes(text={'description': 'The recording features an arpeggiated acoustic guitar melody played over playback. The recording is noisy and in mono, as it was probably recorded with a phone.'}, wav={'self_wav': WavCondition(wav=tensor([[[0.]

In [8]:
!cat public.json | grep '1699168495.217152'

  "1699168495.217152.mp3": "The recording features a cover of a rock song and it consists of an arpeggiated acoustic guitar melody. It sounds groovy and the recording is noisy and in mono.",


In [9]:
from IPython.display import Audio
Audio('generated/1699168495.217152.mp3', autoplay=True)

In [6]:
!zip submission.zip -r generated/

updating: generated/ (stored 0%)
updating: generated/1699168497.7077665.mp3 (deflated 3%)
updating: generated/1699168497.309414.mp3 (deflated 3%)
updating: generated/1699168497.7584674.mp3 (deflated 3%)
updating: generated/1699168496.6261263.mp3 (deflated 2%)
updating: generated/1699168497.4566393.mp3 (deflated 2%)
updating: generated/1699168497.4407313.mp3 (deflated 2%)
updating: generated/1699168497.69898.mp3 (deflated 3%)
updating: generated/1699168495.9059513.mp3 (deflated 3%)
updating: generated/1699168497.6694798.mp3 (deflated 2%)
updating: generated/1699168495.9825058.mp3 (deflated 3%)
updating: generated/1699168498.9220755.mp3 (deflated 3%)
updating: generated/1699168496.5416749.mp3 (deflated 3%)
updating: generated/1699168498.3100872.mp3 (deflated 2%)
updating: generated/1699168497.9484036.mp3 (deflated 3%)
updating: generated/1699168497.229152.mp3 (deflated 3%)
updating: generated/1699168497.3363004.mp3 (deflated 3%)
updating: generated/1699168495.439177.mp3 (deflated 3%)
upd