In [1]:
# import os
# from pydub import AudioSegment

# input_root = "audio"
# output_root = "audio_wav"

# for subdir in os.listdir(input_root):
#     input_subdir = os.path.join(input_root, subdir)
#     output_subdir = os.path.join(output_root, subdir)

#     if not os.path.isdir(input_subdir):
#         continue

#     os.makedirs(output_subdir, exist_ok=True)

#     for filename in os.listdir(input_subdir):
#         if filename.endswith(".mp3"):
#             input_path = os.path.join(input_subdir, filename)
#             output_path = os.path.join(output_subdir, filename.replace(".mp3", ".wav"))

#             try:
#                 audio = AudioSegment.from_mp3(input_path)
#                 audio = audio.set_channels(1).set_frame_rate(32000)
#                 audio.export(output_path, format="wav")
#             except Exception as e:
#                 print(f"⚠️ Failed to convert: {input_path}")
#                 print(f"   Reason: {e}")


In [2]:
import pandas as pd

# Load the uploaded CSV file
csv_path = "./annotations_final.csv"
df = pd.read_csv(csv_path, sep="\t")

# Display the first few rows to understand its structure
df.head(10)

Unnamed: 0,clip_id,no voice,singer,duet,plucking,hard rock,world,bongos,harpsichord,female singing,...,rap,metal,hip hop,quick,water,baroque,women,fiddle,english,mp3_path
0,2,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,f/american_bach_soloists-j_s__bach_solo_cantat...
1,6,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,f/american_bach_soloists-j_s__bach_solo_cantat...
2,10,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,f/american_bach_soloists-j_s__bach_solo_cantat...
3,11,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,f/american_bach_soloists-j_s__bach_solo_cantat...
4,12,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,f/american_bach_soloists-j_s__bach_solo_cantat...
5,14,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,c/lvx_nova-lvx_nova-01-contimune-30-59.mp3
6,19,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,c/lvx_nova-lvx_nova-01-contimune-175-204.mp3
7,21,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,c/lvx_nova-lvx_nova-01-contimune-233-262.mp3
8,23,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,c/lvx_nova-lvx_nova-01-contimune-291-320.mp3
9,25,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0/american_bach_soloists-j_s__bach__cantatas_v...


In [3]:
# size of the dataset
print(f"Dataset size: {len(df)} rows")

Dataset size: 25863 rows


In [4]:
import json

# Get the tag columns (all except clip_id and mp3_path)
tag_columns = df.columns.difference(['clip_id', 'mp3_path'])

# Output list
samples = []

for _, row in df.iterrows():
    tags = [tag.replace("_", " ") for tag in tag_columns if row[tag] == 1]
    if not tags:
        continue  # skip samples with no tags
    
    # Generate simple natural language prompt
    prompt = f"A music clip with " + ", ".join(tags[:-1]) + (" and " + tags[-1] if len(tags) > 1 else tags[0]) + "."

    samples.append({
        "audio_filepath": f"audio_wav/{row['mp3_path'].replace('.mp3', '.wav')}",
        "text": prompt
    })

# Save to JSONL
jsonl_path = "./audio_wav/magnatagatune_text_audio_pairs.jsonl"
with open(jsonl_path, "w") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")

jsonl_path

'./audio_wav/magnatagatune_text_audio_pairs.jsonl'

In [5]:
from transformers import EncodecModel, AutoProcessor
import librosa
import torch
import json
from pathlib import Path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EncodecModel.from_pretrained("facebook/encodec_32khz").to(device)
processor = AutoProcessor.from_pretrained("facebook/encodec_32khz")

input_jsonl = "audio_wav/magnatagatune_text_audio_pairs.jsonl"
output_jsonl = "audio_wav/magnatagatune_tokenized_32khz.jsonl"
Path(input_jsonl).parent.mkdir(parents=True, exist_ok=True)
Path(output_jsonl).parent.mkdir(parents=True, exist_ok=True)

MAX_SAMPLES = 100

with open(input_jsonl, "r") as fin, open(output_jsonl, "w") as fout:
    for idx, line in enumerate(fin):
        if idx >= MAX_SAMPLES:
            break

        item = json.loads(line)
        audio_path = item["audio_filepath"]

        try:
            waveform, sr = librosa.load(audio_path, sr=None, mono=True)
            # print(f"Processing {audio_path} with shape {waveform.shape} and sample rate {sr}")
            inputs = processor(raw_audio=waveform, sampling_rate=sr, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model(**inputs)
                codes = outputs.audio_codes[0].cpu().tolist()

            item["audio_tokens"] = codes
            fout.write(json.dumps(item) + "\n")
        except Exception as e:
            print(f"⚠️ Failed on {audio_path}: {e}")

output_jsonl


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


'audio_wav/magnatagatune_tokenized_32khz.jsonl'

In [6]:
from transformers import AutoTokenizer, AutoModel
import json

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text_encoder = AutoModel.from_pretrained("bert-base-uncased")
text_encoder.eval()

input_path = "audio_wav/magnatagatune_tokenized_32khz.jsonl"
output_path = "audio_wav/flattened_token_with_text_embedding.jsonl"

Path(output_path).parent.mkdir(parents=True, exist_ok=True)

MAX_SAMPLES = 100

with open(input_path, "r") as fin, open(output_path, "w") as fout:
    for idx, line in enumerate(fin):
        if idx >= MAX_SAMPLES:
            break

        item = json.loads(line)
        text = item["text"]
        tokens = item["audio_tokens"] 

        try:
            transposed = list(zip(*tokens)) 
            flattened = [tok for frame in transposed for tok in frame] 

            with torch.no_grad():
                encoded = tokenizer(text, return_tensors="pt")
                embedding = text_encoder(**encoded).last_hidden_state.mean(dim=1).squeeze().tolist()

            fout.write(json.dumps({
                "text": text,
                "text_embedding": embedding,
                "audio_flat_tokens": flattened
            }) + "\n")
        except Exception as e:
            print(f"Failed on sample {idx}: {e}")

output_path

'audio_wav/flattened_token_with_text_embedding.jsonl'

In [16]:
from torch.utils.data import Dataset
from typing import List

# 定义 Dataset 类
class MusicGenDataset(Dataset):
    def __init__(self, jsonl_path: str, max_audio_tokens: int = 1024):
        self.data = []
        with open(jsonl_path, "r") as f:
            for line in f:
                item = json.loads(line)
                audio_tokens = item["audio_flat_tokens"][:max_audio_tokens]
                text_embed = item["text_embedding"]

                embed_dim = 768
                num_text_tokens = 8
                if len(text_embed) != embed_dim:
                    continue

                text_tokens = torch.tensor(text_embed, dtype=torch.float).repeat(num_text_tokens, 1)
                text_token_ids = torch.full((num_text_tokens,), -100)

                MAX_AUDIO_TOKENS = 2048
                flat_audio = [tok for frame in zip(*audio_tokens) for tok in frame]
                flat_audio = flat_audio[:MAX_AUDIO_TOKENS + 1]
                input_ids = torch.tensor(flat_audio[:-1], dtype=torch.long)
                labels    = torch.tensor(flat_audio[1:], dtype=torch.long) 

                self.data.append({
                    "text_embed": text_tokens,         # shape: [N_text, 768]
                    "input_ids": input_ids,            # shape: [T]
                    "labels": labels,                  # shape: [T]
                })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# 构建 Dataset 并保存用于训练
dataset_path = "audio_wav/flattened_token_with_text_embedding.jsonl"
dataset = MusicGenDataset(dataset_path)

# 保存为 .pt 文件
torch.save(dataset, "audio_wav/musicgen_dataset.pt")
"audio_wav/musicgen_dataset.pt"

'audio_wav/musicgen_dataset.pt'

In [None]:
import torch.nn as nn

class MusicGenTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim=768, n_layers=6, n_heads=8):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.prefix_proj = nn.Linear(embed_dim, embed_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads, dim_feedforward=2048, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

        
    def forward(self, text_embed, input_ids):
        tok_emb = self.token_embedding(input_ids)   # [B, T, 768]
        if tok_emb.dim() == 4:
            tok_emb = tok_emb.squeeze(2)

        prefix = self.prefix_proj(text_embed)       # [B, 8, 768]
        x = torch.cat([prefix, tok_emb], dim=1)     # [B, 8+T, 768]

        attn_mask = torch.triu(torch.ones(x.size(1), x.size(1)), 1).bool().to(x.device)
        out = self.transformer(x, mask=attn_mask)
        return self.lm_head(out[:, prefix.size(1):])  # return only audio token logits

In [33]:
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # for debugging
jsonl_path = "audio_wav/flattened_token_with_text_embedding.jsonl"
vocab_size = 1024
embed_dim = 768
max_audio_tokens = 1024
batch_size = 1
num_epochs = 10
lr = 1e-4
max_len = 1024 
# dataset = torch.load("audio_wav/musicgen_dataset.pt", weights_only=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = MusicGenDataset(jsonl_path, max_audio_tokens=max_audio_tokens)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = MusicGenTransformer(vocab_size=vocab_size, embed_dim=embed_dim)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scaler = GradScaler()

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch in loader:
        text_embed = batch["text_embed"].to(device)              # [B, 8, 768]
        input_ids  = batch["input_ids"].to(device)[:, :max_len]  # [B, T]
        labels     = batch["labels"].to(device)[:, :max_len]      # [B, T]

        # 🔒 检查标签合法
        if labels.max() >= vocab_size:
            print("⚠️ 跳过 batch: label 越界")
            continue

        with autocast():
            logits = model(text_embed, input_ids)                # [B, T, V]
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=-100  # 可选
            )

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
import torchaudio
text_embed = torch.randn(1, 8, 768).to(device)


input_ids = torch.full((1, 1), fill_value=0, dtype=torch.long).to(device)  # 初始 token

generated = []
max_len = 1000

with torch.no_grad():
    for _ in range(max_len):
        logits = model(text_embed, input_ids)  # [B, T, V]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy
        input_ids = torch.cat([input_ids, next_token], dim=1)
        generated.append(next_token.item())

audio_tokens = input_ids[0, 1:].cpu().tolist()


codec = EncodecModel.from_pretrained("facebook/encodec_32khz").to("cuda")
processor = AutoProcessor.from_pretrained("facebook/encodec_32khz")

num_codebooks = 8
audio_tokens_tensor = torch.tensor(audio_tokens).view(1, -1)  # [1, T]
audio_tokens_tensor = audio_tokens_tensor.expand(num_codebooks, -1).unsqueeze(0)  # [1, 8, T]

with torch.no_grad():
    wav = codec.decode(audio_tokens_tensor)[0]  # [1, 1, samples]
    wav = wav.squeeze().cpu()  # [samples]

torchaudio.save("output.wav", wav.unsqueeze(0), 32000)
print("result saved as output.wav")