In [1]:
!pip install transformers datasets

Collecting huggingface-hub<1.0,>=0.30.0 (from transformers)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-22.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Downloading huggingface_hub-0.36.0-py3-none-any.whl (566 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m566.1/566.1 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading pyarrow-22.0.0-cp311-cp311-manylinux_2_28_x86_64.whl (47.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m36.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: pyarrow, huggingface-hub
  Attempting uninstall: pyarrow
    Found existing installation: pyarrow 19.0.1
    Uninstalling pyarrow-19.0.1:
      Successfully uninstalled pyarrow-19.0.1
  Attempting uninstall: huggingface-hub
    Found existing installation: huggingface-hub 1.0.0rc2
    Uninstal

In [2]:
from huggingface_hub import login

# Authenticate against the Hugging Face Hub; in the notebook this renders an
# interactive widget.
# NOTE(review): presumably needed to download the model checkpoint — confirm.
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
import csv
import json
import logging
import os
from typing import Optional, Any, TypedDict

import numpy as np
import scipy.io.wavfile
from tqdm import tqdm

import torch
from torch.utils.data import Dataset
from transformers import pipeline, Pipeline

from concurrent.futures import ThreadPoolExecutor, as_completed

# Module-level logger (used by DatasetLoader to report loading failures).
logger = logging.getLogger(__name__)

2025-10-24 17:08:07.614537: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761325687.788151      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761325687.841128      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [4]:
class PromptItemLoader(TypedDict):
    """Typed shape of a single input record: an identifier and its text prompt."""

    # Record identifier; the generation loop uses it to name the output .wav file.
    id: str
    # Text prompt handed to the text-to-audio synthesiser.
    prompt: str


class OutputGeneratedAudioItem(TypedDict):
    """Typed shape of one generated-audio result entry.

    NOTE(review): this type is not referenced by the visible code, and
    ``generate_audio_from_prompts`` stores the prompt text under the key
    ``"description"`` rather than ``"prompt"`` — confirm which name is intended.
    """

    # Identifier of the prompt record.
    id: str
    # Text prompt associated with the audio.
    prompt: str
    # Path of the audio file on disk.
    audio_path: str

In [5]:
def read_json(file_path: str) -> dict[str, Any] | list[dict[str, Any]]:
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data


def read_csv(file_path: str) -> list[dict[str, Any]]:
    """Read a UTF-8 encoded CSV file into a list of row dictionaries.

    The first row of the file supplies the keys (``csv.DictReader`` default).

    Args:
        file_path: Location of the CSV file to read.

    Returns:
        One dict per data row, in file order; an empty list when the file has
        no data rows.
    """
    # newline="" hands newline handling to the csv module, so line breaks
    # embedded in quoted fields are preserved instead of being rewritten.
    with open(file_path, "r", encoding="utf-8", newline="") as f:
        return list(csv.DictReader(f))


class DatasetLoader(Dataset):
    """
    Dataset representing the input data for the `taste-music-dataset` model.

    Attributes:
        file_path (str): Path of the input file (JSON or CSV) holding the prompts.
        format_ (str): Input format used to select the reader ("json" or "csv").
        records (list[PromptItemLoader]): Loaded records; stays empty when
            loading fails.
    """

    def __init__(self, file_path, format_: str = "json"):
        """
        Parameters:
            file_path (str): Path of the input file with the prompts.
            format_ (str): Format of the input file, "json" or "csv".
        """
        super().__init__()
        self.file_path = file_path
        self.format_ = format_
        self.records: list[PromptItemLoader] = []
        self._load_data(self.format_)

    def _load_data(self, file_format: str) -> None:
        """
        Fill ``self.records`` using the reader registered for ``file_format``.

        Loading is best-effort: any failure (including an unsupported format)
        is printed and logged, and ``self.records`` keeps its previous value.
        """
        read_functions = {
            "json": read_json,
            "csv": read_csv,
        }
        try:
            # Report an unknown format explicitly; a plain dict lookup would
            # only surface the bare key (e.g. "'xml'") in the error message.
            if file_format not in read_functions:
                raise ValueError(
                    f"formato no soportado: {file_format!r}; "
                    f"use uno de {sorted(read_functions)}"
                )
            self.records = read_functions[file_format](self.file_path)
        except Exception as e:
            print(f"Error cargando {self.file_path}: {e}")
            logger.critical(e, exc_info=True)

    def __len__(self) -> int:
        """
        Returns:
            int: Number of available records.
        """
        return len(self.records)

    def __getitem__(self, idx) -> PromptItemLoader:
        """
        Parameters:
            idx (int): Index of the record to return.
        Returns:
            The dataset item at the given position.
        Raises:
            IndexError: If ``idx`` is not in ``[0, len(self))``; negative
                indices are rejected as well.
        """
        if not 0 <= idx < len(self.records):
            raise IndexError(
                f"indice {idx} fuera de rango para dataset: {len(self.records)}."
            )
        return self.records[idx]

In [6]:
def load_musicgen_pipeline(
    model_name: str = "csc-unipd/tasty-musicgen-small",
    prefer_gpu: bool = True,
    dtype: Optional[torch.dtype] = None,
) -> Pipeline:
    """Build the ``text-to-audio`` pipeline for ``model_name``.

    Args:
        model_name: Model identifier passed to ``pipeline`` as ``model``.
        prefer_gpu: When true and CUDA is available, run on ``"cuda:0"``;
            otherwise run on ``"cpu"``.
        dtype: Value forwarded as ``torch_dtype``. When left as ``None`` it
            becomes ``torch.float16`` on CUDA and stays ``None`` on CPU.

    Returns:
        The pipeline object created by ``transformers.pipeline``.
    """
    if prefer_gpu and torch.cuda.is_available():
        device = "cuda:0"
        # Only choose a default precision; an explicit dtype passes through untouched.
        if dtype is None:
            dtype = torch.float16
    else:
        device = "cpu"

    return pipeline(
        task="text-to-audio", model=model_name, torch_dtype=dtype, device=device
    )

In [8]:
def generate_audio_from_prompts(
    synthesiser, dataset, output_dir="generated_music", sample_rate=32000
) -> list[dict[str, Any]]:
    """Generate one ``.wav`` file per prompt record and report what was written.

    Args:
        synthesiser: Callable invoked as
            ``synthesiser(prompt, forward_params={"do_sample": True})`` that
            returns a mapping with an ``"audio"`` entry and, optionally, a
            ``"sampling_rate"`` entry.
        dataset: Object supporting ``len()`` and exposing a ``records``
            iterable of mappings with ``"id"`` and ``"prompt"`` keys.
        output_dir: Directory for the generated files; created if missing.
        sample_rate: Fallback sampling rate, used only when the synthesiser
            output has no ``"sampling_rate"`` entry.

    Returns:
        One dict per successfully written file with the keys ``"id"``,
        ``"description"`` (the prompt text) and ``"audio_path"``. Records
        whose generation or saving failed are omitted.
    """
    os.makedirs(output_dir, exist_ok=True)
    results = []

    print(f"Generando música para {len(dataset)} prompts...\n")

    for record in tqdm(dataset.records):
        text_prompt = record["prompt"]
        file_id = record["id"]

        try:
            # 1. Generate the music with the model.
            output = synthesiser(text_prompt, forward_params={"do_sample": True})

            # 2. Pull the waveform and the sampling rate out of the output.
            # NOTE(review): the array is written as returned — confirm its
            # layout matches what scipy expects for multi-channel checkpoints.
            audio_data = output["audio"]
            sr = output.get("sampling_rate", sample_rate)

            # 3. Write the waveform to <output_dir>/<id>.wav.
            output_path = os.path.join(output_dir, f"{file_id}.wav")
            scipy.io.wavfile.write(output_path, rate=sr, data=audio_data)

            # 4. Record what was produced.
            results.append(
                {"id": file_id, "description": text_prompt, "audio_path": output_path}
            )
        except Exception as e:
            # Best-effort batch: one failing prompt must not abort the run,
            # but keep the full traceback in the log for later diagnosis.
            print(f" Error generando {file_id}: {e}")
            logger.exception("Error generando %s", file_id)

    print(f"\n {len(results)} archivos de audio generados en: {output_dir}")
    return results

In [None]:
PROMPTS_PATH = "/kaggle/input/prompts/spanio_prompts_v2.csv"
OUTPUT_DIR = "/kaggle/working/generated_music"

dataset = DatasetLoader(PROMPTS_PATH, format_="csv")
# DatasetLoader reports load errors without raising, so stop here instead of
# downloading and loading the model for an empty dataset.
assert len(dataset) > 0, f"No se cargaron prompts desde {PROMPTS_PATH}"

synth = load_musicgen_pipeline()

generated_audio = generate_audio_from_prompts(
    synthesiser=synth, dataset=dataset, output_dir=OUTPUT_DIR
)
# Show a small sample rather than echoing the full result list as cell output.
generated_audio[:5]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/2.36G [00:00<?, ?B/s]



generation_config.json:   0%|          | 0.00/224 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Device set to use cuda:0


Generando música para 100 prompts...



  6%|▌         | 6/100 [03:45<58:30, 37.35s/it]  