<h1>MusicGen </h1>

<h2>Install Dependencies</h2>

In [1]:
%pip install lightning audiocraft torch numpy pandas torchaudio librosa matplotlib scipy
!sudo apt install ffmpeg
!git clone https://github.com/facebookresearch/audiocraft.git

Collecting torch
  Using cached torch-2.1.0-cp310-cp310-manylinux1_x86_64.whl.metadata (25 kB)
Collecting torchaudio
  Using cached torchaudio-2.1.1-cp310-cp310-manylinux1_x86_64.whl.metadata (6.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)


<h2>Imports and Dependencies</h2>

In [1]:
# Import necessary libraries for data handling, audio processing, and deep learning.
import os  # For interacting with the operating system (e.g., file paths).
import numpy as np
import pandas as pd  # For data manipulation and analysis.
import torch  # PyTorch: A machine learning library for building models.
import torch.nn as nn  # Neural network module in PyTorch.
import torch.nn.functional as F  # Functional interface for PyTorch operations.
import torchaudio  # PyTorch's audio processing library.
import lightning.pytorch as L  # PyTorch Lightning: Simplifies PyTorch training.
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping  # Callbacks for saving models and stopping early.
from torch.optim.lr_scheduler import ReduceLROnPlateau  # Learning rate scheduler to reduce learning rate on plateau.
from audiocraft.models.musicgen import MusicGen  # Pretrained music generation model from Meta's AudioCraft.
from audiocraft.modules.conditioners import ClassifierFreeGuidanceDropout  # For classifier-free guidance.
from sklearn.model_selection import train_test_split  # To split the dataset into training and validation sets.
from audiocraft.data.audio import audio_write
from scipy.io import wavfile


<h2>Audio Dataset Class</h2>

In [8]:
# Define a custom dataset class for loading audio and captions from a CSV file.
class CSVAudioDataset(torch.utils.data.Dataset):
    def __init__(self, df, audio_root, segment_duration=30, sample_rate=44100):
        # Initialize the dataset with a DataFrame, root directory for audio files, segment duration, and sample rate.
        self.df = df.reset_index(drop=True)  # Reset index of the DataFrame for consistent indexing.
        self.audio_root = audio_root  # Root directory where audio files are stored.
        self.segment_duration = segment_duration  # Duration of each audio segment in seconds.
        self.sample_rate = sample_rate  # Sample rate for audio processing.

    def __len__(self):
        # Return the total number of samples in the dataset.
        return len(self.df)

    def __getitem__(self, idx):
        # Get a single item from the dataset based on the given index.
        row = self.df.iloc[idx]  # Retrieve the row corresponding to the index.
        ytid = row['ytid']  # Extract the YouTube ID from the row.
        audio_path = os.path.join(self.audio_root, f"{ytid}.wav")  # Construct the path to the audio file.
        
        waveform, sr = torchaudio.load(audio_path)  # Load the audio waveform and its sample rate.
        if sr != self.sample_rate:
            # Resample the audio if the sample rate does not match the desired rate.
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        
        if waveform.shape[0] > 1:
            # If the audio has more than one channel, keep only the first channel.
            waveform = waveform[0:1, :]
        
        num_samples = int(self.segment_duration * self.sample_rate)  # Calculate the number of samples for the segment.
        if waveform.shape[1] > num_samples:
            # If the waveform is longer than the segment duration, truncate it.
            waveform = waveform[:, :num_samples]
        else:
            # If the waveform is shorter, pad it with zeros to match the segment duration.
            waveform = F.pad(waveform, (0, num_samples - waveform.shape[1]))
        
        caption = row['caption']  # Extract the caption associated with the audio.
        return waveform, caption  # Return the processed waveform and caption.

<h2>Helper Function for Condition Tensor</h2>

In [9]:
# Function to generate condition tensors for classifier-free guidance.
def get_condition_tensor(model, attributes):
    null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(attributes)  # Generate null conditions for classifier-free guidance.
    conditions = attributes + null_conditions  # Combine original and null conditions.
    tokenized = model.lm.condition_provider.tokenize(conditions)  # Tokenize the combined conditions.
    cfg_conditions = model.lm.condition_provider(tokenized)  # Generate condition tensors from tokenized conditions.
    return cfg_conditions  # Return the condition tensors.

<h2>Lightning Module for MusicGen Fine-Tuning</h2>

In [None]:
# Define the LightningModule for fine-tuning the MusicGen model.
class MusicGenFinetuning(L.LightningModule):
    def __init__(self):
        super().__init__()  # Initialize the parent class.
        self.model = MusicGen.get_pretrained("facebook/musicgen-Large")  # Load the pretrained MusicGen model.
        self.model.lm.train()  # Set the language model part of MusicGen to training mode.
        self.model.lm = self.model.lm.float()  # Ensure the model uses float32 precision.

    def training_step(self, batch, batch_idx):
        # Perform a single training step.
        waveforms, captions = batch  # Unpack the batch into waveforms and captions.
        waveforms = waveforms.to(self.device)  # Move waveforms to the appropriate device (GPU/CPU).
        batch_size = waveforms.size(0)  # Get the batch size.
        
        with torch.no_grad():
            # Encode the waveforms into discrete codes using the compression model.
            codes, _ = self.model.compression_model.encode(waveforms)
        
        codes = codes.repeat_interleave(2, dim=0)  # Duplicate the codes for classifier-free guidance.
        attributes, _ = self.model._prepare_tokens_and_attributes(captions, None)  # Prepare attributes from captions.
        condition_tensors = get_condition_tensor(self.model, attributes)  # Generate condition tensors.
        

        lm_output = self.model.lm.compute_predictions(
            codes=codes, 
            conditions=[], 
            condition_tensors=condition_tensors
        )  # Compute predictions using the language model.
        
        logits = lm_output.logits  # Extract the logits from the model output.
        mask = lm_output.mask  # Extract the mask indicating valid tokens.
        logits_flat = logits[mask].view(-1, logits.size(-1))  # Flatten the logits for valid tokens.
        targets_flat = codes[mask].view(-1)  # Flatten the target codes for valid tokens.
        
        loss = F.cross_entropy(logits_flat, targets_flat)  # Compute the cross-entropy loss.
        accuracy = (logits_flat.argmax(-1) == targets_flat).float().mean()
        
        # Log training metrics.
        self.log("train_loss", loss, prog_bar=True, batch_size=batch_size)
        self.log("train_accuracy", accuracy, prog_bar=True, batch_size=batch_size)
        return loss  # Return the loss for optimization.

    def validation_step(self, batch, batch_idx):
        # Perform a single validation step.
        waveforms, captions = batch  # Unpack the batch into waveforms and captions.
        waveforms = waveforms.to(self.device)  # Move waveforms to the appropriate device.
        batch_size = waveforms.size(0)  # Get the batch size.
        
        with torch.no_grad():
            # Encode the waveforms into discrete codes using the compression model.
            codes, _ = self.model.compression_model.encode(waveforms)
        
        codes = codes.repeat_interleave(2, dim=0)  # Duplicate the codes for classifier-free guidance.
        attributes, _ = self.model._prepare_tokens_and_attributes(captions, None)  # Prepare attributes from captions.
        condition_tensors = get_condition_tensor(self.model, attributes)  # Generate condition tensors.
        
        lm_output = self.model.lm.compute_predictions(
            codes=codes, 
            conditions=[], 
            condition_tensors=condition_tensors
        )  # Compute predictions using the language model.
        
        logits = lm_output.logits  # Extract the logits from the model output.
        mask = lm_output.mask  # Extract the mask indicating valid tokens.
        logits_flat = logits[mask].view(-1, logits.size(-1))  # Flatten the logits for valid tokens.
        targets_flat = codes[mask].view(-1)  # Flatten the target codes for valid tokens.
        
        val_loss = F.cross_entropy(logits_flat, targets_flat)  # Compute the validation loss.
        val_acc = (logits_flat.argmax(-1) == targets_flat).float().mean()  # Compute validation accuracy.
       
        # Log validation metrics.
        self.log("val_loss", val_loss, prog_bar=True, batch_size=batch_size)
        self.log("val_accuracy", val_acc, prog_bar=True, batch_size=batch_size)
        return val_loss  # Return the validation loss.

    def configure_optimizers(self):
        # Configure the optimizer and learning rate scheduler.
        optimizer = torch.optim.AdamW(self.model.lm.parameters(), lr=1e-5)  # Use AdamW optimizer with a learning rate of 1e-5.
        scheduler = ReduceLROnPlateau(  # Use a learning rate scheduler that reduces the learning rate on a plateau.
            optimizer,
            mode="min",  # Monitor the metric to minimize (validation loss).
            factor=0.1,  # Reduce the learning rate by a factor of 0.1.
            patience=2,  # Wait for 2 epochs before reducing the learning rate.
            verbose=True,  # Print messages when the learning rate is reduced.
            min_lr=1e-6  # Set the minimum learning rate.
        )
        return {
            "optimizer": optimizer,  # Return the optimizer.
            "lr_scheduler": {  # Return the learning rate scheduler configuration.
                "scheduler": scheduler,  # The scheduler object.
                "monitor": "val_loss",  # Monitor the validation loss.
                "interval": "epoch",  # Update the scheduler every epoch.
                "frequency": 1,  # Frequency of updates.
            },
        }

<h2>Data Preparation and Dataloaders</h2>

In [13]:
if __name__ == "__main__":
    L.seed_everything(42)  # Set a random seed for reproducibility.
    
    full_df = pd.read_csv("metadata_updated.csv")  # Load the metadata CSV file.
    train_df, val_df = train_test_split(full_df, test_size=0.2, random_state=42)  # Split the dataset into training and validation sets.
    
    train_dataset = CSVAudioDataset(  # Create the training dataset.
        train_df, 
        audio_root="/teamspace/studios/this_studio/music_data"
    )
    val_dataset = CSVAudioDataset(  # Create the validation dataset.
        val_df, 
        audio_root="/teamspace/studios/this_studio/music_data"
    )
    
    train_dataloader = torch.utils.data.DataLoader(  # Create the training dataloader.
        train_dataset,
        batch_size=2,  # Batch size of 2.
        shuffle=True,  # Shuffle the data for each epoch.
        num_workers=2,  # Use 2 worker threads for data loading.
        collate_fn=lambda x: (  # Custom collate function to handle batches.
            torch.stack([sample[0] for sample in x]),
            [sample[1] for sample in x]
        )
    )
    
    val_dataloader = torch.utils.data.DataLoader(  # Create the validation dataloader.
        val_dataset,
        batch_size=2,  # Batch size of 2.
        shuffle=False,  # Do not shuffle the validation data.
        num_workers=2,  # Use 2 worker threads for data loading.
        collate_fn=lambda x: (  # Custom collate function to handle batches.
            torch.stack([sample[0] for sample in x]),
            [sample[1] for sample in x]
        )
    )

Seed set to 42


<h2>Training Configuration and Execution</h2>

In [15]:
model = MusicGenFinetuning()  # Instantiate the fine-tuning model.
trainer = L.Trainer(  # Create the trainer.
    precision="16-mixed",  # Use mixed-precision training (FP16).
    max_epochs=10,  # Train for a maximum of 10 epochs.
    val_check_interval=0.5,  # Check validation metrics every 50% of an epoch.
    callbacks=[  # Add callbacks for checkpointing and early stopping.
        ModelCheckpoint(monitor="val_loss"),  # Save the best model based on validation loss.
        EarlyStopping(monitor="val_loss", patience=3)  # Stop training if validation loss does not improve for 3 epochs.
    ]
    )
trainer.fit(model, train_dataloader, val_dataloader)  # Start the training process.

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params | Mode
-------------------------------------
-------------------------------------
0         Trainable params
0         Non-trainable params
0         Total params
0.000     Total estimated model params size (MB)
0         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

<h2>Utility Functions</h2>

In [4]:
checkpoint = torch.load("lightning_logs/version_7/checkpoints/epoch=0-step=1490.ckpt")
state_dict = checkpoint["state_dict"]

model = MusicGen.get_pretrained("facebook/musicgen-Large")
model.lm.load_state_dict(state_dict,strict=False)
model.lm.float()
model.lm.eval()



LMModel(
  (cfg_dropout): ClassifierFreeGuidanceDropout(p=0.3)
  (att_dropout): AttributeDropout({})
  (condition_provider): ConditioningProvider(
    (conditioners): ModuleDict(
      (description): T5Conditioner(
        (output_proj): Linear(in_features=768, out_features=1024, bias=True)
      )
    )
  )
  (fuser): ConditionFuser()
  (emb): ModuleList(
    (0-3): 4 x ScaledEmbedding(2049, 1024)
  )
  (transformer): StreamingTransformer(
    (layers): ModuleList(
      (0-23): 24 x StreamingTransformerLayer(
        (self_attn): StreamingMultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (linear1): Linear(in_features=1024, out_features=4096, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=4096, out_features=1024, bias=False)
        (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  

<h2>UnConditional Generation</h2>

In [15]:
@torch.no_grad()
def generate(model, num_samples, total_gen_len, temp=1.0, top_k=250):
    model.set_generation_params(use_sampling=True, top_k=top_k, duration=30)

    samples = []
    prev_generated = None
    descriptions = [None]

    attributes, prompt_tokens = model._prepare_tokens_and_attributes(descriptions, None)

    for i in range(num_samples):
        sample = model.lm.generate(
            prev_generated,
            attributes,
            num_samples=1,
            max_gen_len=total_gen_len,
            remove_prompts=(i > 0),
            temp=temp,
            top_k=top_k,
        )

        if sample.shape[2] == total_gen_len:
            prev_generated = torch.clone(sample[..., total_gen_len // 2 :])
        else:
            prev_generated = torch.clone(sample)

        with torch.no_grad():
            gen_audio = model.compression_model.decode(sample, None) 
        del sample

        gen_audio = gen_audio[0].detach().cpu().numpy().transpose(1, 0)
        samples.append(gen_audio)

    return samples


def save_audio(samples, path):
    audio = np.concatenate(samples, axis=0)
    audio = np.squeeze(audio)
    wavfile.write(path, 32000, audio.astype(np.float32))


In [16]:
torch.manual_seed(45)
generated_audio = generate(model, num_samples=1, total_gen_len=1524)
save_audio(generated_audio, "generated_Audio2.wav")

<h2>Conditional Generation</h2>

In [9]:
descriptions = ["A dynamic blend of hip-hop and orchestral elements, with sweeping strings and brass, evoking the vibrant energy of the city."]
model.set_generation_params(duration=30)
generated_audio = model.generate(descriptions)
from audiocraft.data.audio import audio_write

for idx, audio_sample in enumerate(generated_audio):
    filename = f"generated_music{idx:03d}.wav"  
    waveform = audio_sample.cpu().numpy()  
    
    print(f"Saving: {filename}")  
    print(f"Waveform shape: {waveform.shape}")  
    print(f"Waveform data (first 10 samples): {waveform[:10]}")  

    audio_write(filename, audio_sample.cpu(), model.sample_rate)

Saving: generated_music000.wav
Waveform shape: (2, 960000)
Waveform data (first 10 samples): [[ 0.18038043  0.17244084  0.1827398  ... -0.19096588 -0.1991812
  -0.19858283]
 [ 0.03373905  0.02937443  0.08155953 ...  0.0873717   0.08849183
   0.09535955]]


<h2>plot the waveform</h2>

In [None]:
import matplotlib.pyplot as plt
import librosa
import numpy as np
from scipy.io.wavfile import write

# Load the generated WAV file
waveform, sr = librosa.load("generated_music000.wav", sr=None)

# Plot the waveform
plt.figure(figsize=(12, 4))
plt.plot(np.linspace(0, len(waveform)/sr, len(waveform)), waveform)
plt.title("Waveform of Generated Music")
plt.xlabel("Time (seconds)")
plt.ylabel("Amplitude")
plt.show()
print("Waveform:", waveform[:960000])

<h2>done</h2>