In [None]:
!pip install torch librosa scikit-learn soundfile transformers datasets numpy matplotlib pandas torchcodec 

In [3]:
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from typing import Optional

#version_1

class MusicgenWithResiduals:
    """Wrap a Hugging Face MusicGen model and capture decoder hidden states.

    A forward hook registered on ``model.decoder.model.decoder`` stores the
    per-layer hidden states of the most recent forward pass through that
    module in ``self.hidden_states`` under keys ``"decoder.layer.<i>"``.
    Every pass replaces the previous capture, so after a generation call the
    dictionary holds whatever the last hooked forward pass produced.
    """

    def __init__(
        self,
        model_name: str = "facebook/musicgen-small",  # or "facebook/musicgen-medium", "facebook/musicgen-large"
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Load model and processor, freeze both encoders and attach the hook.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.
            device: Device the model is moved to; inputs built by
                ``generate_with_residuals`` are moved to the same device.
        """
        print(f"Loading model {model_name} to {device}...")
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            model_name,
            trust_remote_code=True,
            output_hidden_states=True
        ).to(device)

        print("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model.freeze_text_encoder()
        self.model.freeze_audio_encoder()

        self.device = device
        # Filled by the forward hook: {"decoder.layer.<i>": hidden state}
        self.hidden_states = {}

        decoder_stack = self.model.decoder.model.decoder
        # The layer list does not change after loading, so the names are
        # built once here instead of on every forward pass.
        if hasattr(decoder_stack, 'layers'):
            layer_names = [f"decoder.layer.{i}" for i in range(len(decoder_stack.layers))]
        else:
            layer_names = []

        def hook_fn(module, input, output):
            captured = getattr(output, "hidden_states", None)
            if captured is not None:
                # Entry i + 1 is stored for layer i; entry 0 is skipped
                # (presumably the embedding output - TODO confirm).
                self.hidden_states = {
                    name: captured[i + 1]
                    for i, name in enumerate(layer_names)
                }
            else:
                # Also reached when the attribute exists but is None, which
                # would otherwise fail with a TypeError while indexing.
                print(f"Output structure: {type(output)} - {output}")

        # Keep the handle so the hook can be detached with .remove() later.
        self._hook_handle = decoder_stack.register_forward_hook(hook_fn)
        print("Model ready!")

    def generate_with_residuals(
        self,
        text: Optional[str] = None,
        audio: Optional[torch.Tensor] = None,
        sampling_rate: Optional[int] = None,
        max_new_tokens: int = 10,
        temperature=1e-3,
        **kwargs
    ):
        """Run generation and return the audio plus the captured hidden states.

        Args:
            text: Optional text prompt handed to the processor.
            audio: Optional audio prompt handed to the processor (a tensor or
                any array type the processor accepts).
            sampling_rate: Sampling rate of ``audio``, handed to the processor.
            max_new_tokens: Forwarded to ``model.generate``.
            temperature: Forwarded to ``model.generate``.
            **kwargs: Extra generation options (e.g. ``guidance_scale``,
                ``do_sample``) forwarded to ``model.generate``.  The options
                this method sets itself take precedence on a name clash.

        Returns:
            dict with keys ``"audio_values"`` (``outputs.sequences``),
            ``"residual_streams"`` (layer name -> hidden state captured by the
            hook) and ``"sampling_rate"`` (from the audio encoder config).
        """
        self.model.decoder.config.output_hidden_states = True
        # Start from a clean slate so states captured by an earlier call can
        # never be returned as if they belonged to this one.
        self.hidden_states = {}

        if text is None and audio is None:
            inputs = self.model.get_unconditional_inputs(num_samples=1)
        else:
            inputs = self.processor(
                text=text,
                audio=audio,
                sampling_rate=sampling_rate,
                padding=True,
                return_tensors="pt"
            ).to(self.device)

        # Caller-supplied options first, then the ones this method controls,
        # so a duplicate key cannot raise "multiple values for keyword".
        generation_kwargs = dict(kwargs)
        generation_kwargs.update(
            output_hidden_states=True,
            return_dict_in_generate=True,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        return {
            "audio_values": outputs.sequences,
            "residual_streams": self.hidden_states,  # Return the dictionary of hidden states
            "sampling_rate": self.model.config.audio_encoder.sampling_rate
        }


2025-08-23 12:53:34.833312: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755953615.136226      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755953615.248247      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
# Get the residuals for the given input samples
import numpy as np
import librosa
from datasets import load_dataset
import re
import os

def sanitize_string(s):
    """Return ``s`` with every character outside a-z, A-Z and 0-9 removed."""
    disallowed = re.compile(r'[^a-zA-Z0-9]')
    cleaned = disallowed.sub('', s)
    return cleaned

def process_with_residuals(model, data, count, target_sr=32000, segment_seconds=10, max_new_tokens=512):
    """Generate a continuation for each chunk of one dataset example.

    The first audio entry of ``data`` is resampled to ``target_sr`` (only when
    its rate differs), cut into consecutive chunks of ``segment_seconds``
    seconds, and each chunk is passed to ``model.generate_with_residuals``
    together with a genre-specific text prompt.

    Args:
        model: Object exposing ``generate_with_residuals`` as called below.
        data: Batch of size one with ``data['audio'][0]['array']``,
            ``data['audio'][0]['sampling_rate']`` and ``data['genre'][0]``.
            The audio is sliced along axis 0 (assumes mono, 1-D audio -
            TODO confirm against the dataset).
        count: Unused; kept so existing callers keep working.
        target_sr: Sampling rate the audio is resampled to and reported to
            the model.
        segment_seconds: Chunk length in whole seconds (positive integer).
        max_new_tokens: Forwarded to ``generate_with_residuals``.

    Returns:
        A list with one dict per chunk holding ``genre``, ``generated_audio``,
        ``residual_stream``, ``sampling_rate`` and ``prompt_used``.

    Raises:
        ValueError: If ``segment_seconds`` is not positive.
    """
    if segment_seconds <= 0:
        raise ValueError("segment_seconds must be a positive integer")

    example = data['audio'][0]
    audio = np.array(example['array'])
    sr = example['sampling_rate']
    genre = data['genre'][0]
    prompt = f"Generate {genre} music continuing the given audio"

    # Resampling is only needed when the source rate differs from the target.
    if sr != target_sr:
        audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr)

    # Chunk starts are whole seconds below the floored duration, so the final
    # chunk may be shorter than segment_seconds but is never empty.
    whole_seconds = int(audio.shape[0] // target_sr)
    audio_segments = [
        audio[start * target_sr:(start + segment_seconds) * target_sr]
        for start in range(0, whole_seconds, segment_seconds)
    ]

    result = []
    for segment in audio_segments:
        outputs = model.generate_with_residuals(
            text=prompt,
            audio=segment,
            sampling_rate=target_sr,
            max_new_tokens=max_new_tokens,
            guidance_scale=3.0,
            do_sample=True
        )
        print('generated outputs')
        # Stack the captured hidden states of every layer, in dict order,
        # into a single array whose first axis indexes the layer.
        residual = np.array([
            layer_state.detach().cpu().numpy()
            for layer_state in outputs['residual_streams'].values()
        ])

        # Everything needed to persist this chunk's result
        result.append({
            'genre': genre,
            'generated_audio': outputs['audio_values'].detach().cpu().numpy(),
            'residual_stream': residual,
            'sampling_rate': outputs['sampling_rate'],
            'prompt_used': prompt
        })

    return result

def save_activations(
    save_dir,
    dataset,
    model=None,
    log_every=1,
):
    """Run every dataset example through the model and save one .npz per chunk.

    Args:
        save_dir: Output directory; created if it does not exist.
        dataset: Object whose ``iter(batch_size=1)`` yields batches that
            ``process_with_residuals`` accepts.
        model: Model wrapper handed to ``process_with_residuals``.  When
            omitted, the module-level ``model`` variable is used so existing
            two-argument calls keep working.
        log_every: Print a progress line after every ``log_every`` completed
            examples; a falsy value disables the progress line.

    Raises:
        NameError: If ``model`` is omitted and no global ``model`` exists.

    Files are named ``<genre>_<example index>_<chunk index>.npz``, where the
    genre is passed through ``sanitize_string`` so characters such as ``/``
    cannot end up in the path.  The unmodified genre is still stored inside
    each file.
    """
    if model is None:
        # Backward-compatible fallback to the notebook-level variable.
        model = globals().get('model')
        if model is None:
            raise NameError(
                "name 'model' is not defined; pass model=... or define a global `model`"
            )

    os.makedirs(save_dir, exist_ok=True)
    print('Made folder')

    for idx, data in enumerate(dataset.iter(batch_size=1)):
        genre = data['genre'][0]
        name = f"{sanitize_string(str(genre))}_{idx}"
        processed_data = process_with_residuals(model, data, idx)

        for segment_idx, record in enumerate(processed_data):
            save_path = os.path.join(save_dir, f"{name}_{segment_idx}.npz")
            # Each key of the record becomes one named array in the archive
            np.savez(save_path, **record)
            print(f"processed {name}_{segment_idx}")

        # Report once per finished example, using a 1-based count
        completed = idx + 1
        if log_every and completed % log_every == 0:
            print(f"Processed {completed} samples")

In [None]:
# Build the MusicGen wrapper with its default checkpoint and device, open the
# dataset in streaming mode, and write the activations for its 'train' split
# into the 'lewtun' directory.
model = MusicgenWithResiduals()
lewtun_modified = load_dataset(
    "roovy54/lewtun_music_genres_modified",
    streaming=True,
)
save_activations(save_dir='lewtun', dataset=lewtun_modified['train'])

In [None]:
%%bash
# Stop at the first failure so the data is only deleted after a successful zip.
set -euo pipefail
# The activations are written to the 'lewtun' folder, so archive that folder.
cd /kaggle/working/lewtun
zip -r /kaggle/working/output_lewtun.zip .
# rmdir cannot remove a non-empty directory; step out and delete it recursively.
cd /kaggle/working
rm -r lewtun