In [8]:
%pip install librosa
%pip install pydub
%pip install peft

Collecting librosa
  Downloading librosa-0.10.2.post1-py3-none-any.whl.metadata (8.6 kB)
Collecting audioread>=2.1.9 (from librosa)
  Using cached audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting numba>=0.51.0 (from librosa)
  Downloading numba-0.60.0-cp312-cp312-win_amd64.whl.metadata (2.8 kB)
Collecting soundfile>=0.12.1 (from librosa)
  Using cached soundfile-0.12.1-py2.py3-none-win_amd64.whl.metadata (14 kB)
Collecting pooch>=1.1 (from librosa)
  Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Collecting soxr>=0.3.2 (from librosa)
  Downloading soxr-0.5.0.post1-cp312-abi3-win_amd64.whl.metadata (5.6 kB)
Collecting lazy-loader>=0.1 (from librosa)
  Downloading lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)
Collecting msgpack>=1.0 (from librosa)
  Downloading msgpack-1.1.0-cp312-cp312-win_amd64.whl.metadata (8.6 kB)
Collecting llvmlite<0.44,>=0.43.0dev0 (from numba>=0.51.0->librosa)
  Downloading llvmlite-0.43.0-cp312-cp312-win_amd64.whl.metadata (4.9 kB)
D

In [14]:
# Import the main libraries
import os
import torch
import json
import librosa
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer
from pydub import AudioSegment
from peft import PromptEncoder, PromptEncoderConfig
import soundfile as sf

In [15]:
# Dataset class definition
class JSONAudioDataset(Dataset):
    """Dataset of ``(waveform, text_condition)`` pairs listed in a JSON manifest.

    The manifest is a JSON object whose ``"data"`` entry is a list of records.
    Each record must provide ``audio_file`` (path to the clip), ``description``
    (string), ``keywords`` (list of strings) and ``moods`` (list of strings).

    Args:
        json_path: Path to the JSON manifest.
        sample_rate: Sampling rate every clip is resampled to when loaded.
    """

    def __init__(self, json_path, sample_rate=32000):
        # Decode explicitly as UTF-8: the platform default (a locale code page
        # on Windows) corrupts or rejects non-ASCII descriptions.
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)["data"]
        self.data = data
        self.sample_rate = sample_rate

    def __len__(self):
        """Return the number of records in the manifest."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(audio, text_condition)`` for record ``idx``.

        ``audio`` is whatever :meth:`load_audio` returns for the record's
        ``audio_file``; ``text_condition`` joins the description, keywords and
        moods into a single prompt string.
        """
        sample = self.data[idx]
        audio_path = sample["audio_file"]
        text_condition = f"{sample['description']} Keywords: {', '.join(sample['keywords'])}. Moods: {', '.join(sample['moods'])}."
        audio = self.load_audio(audio_path)
        return audio, text_condition

    def load_audio(self, path):
        """Load ``path`` as a 1-D tensor resampled to ``self.sample_rate``.

        MP3 inputs are converted once to a sibling ``.wav`` file, which is then
        loaded instead. On any loading error the problem is printed and a
        one-element zero tensor is returned so iteration can continue.
        """
        # Swap only the file extension. ``str.replace`` would also rewrite
        # ".mp3" occurring in a directory name or in the middle of the stem.
        stem, ext = os.path.splitext(path)
        if ext.lower() == ".mp3":
            wav_path = stem + ".wav"
            if not os.path.exists(wav_path):
                self.convert_mp3_to_wav(path, wav_path)
            path = wav_path
        try:
            audio, _ = librosa.load(path, sr=self.sample_rate)
            return torch.tensor(audio)
        except Exception as e:
            # Best effort by design: a broken clip must not abort the epoch.
            print(f"Error loading audio file {path}: {e}")
            return torch.zeros(1)

    def convert_mp3_to_wav(self, mp3_path, wav_path):
        """Export ``mp3_path`` as a WAV file at ``wav_path``.

        Failures are printed rather than raised; the caller then fails to load
        the missing WAV and falls back to its placeholder tensor.
        """
        try:
            audio = AudioSegment.from_mp3(mp3_path)
            audio.export(wav_path, format="wav")
            print(f"Converted {mp3_path} to {wav_path}")
        except Exception as e:
            print(f"Error converting {mp3_path} to WAV: {e}")


In [16]:
# PEFT prompt condition provider
class PEFTPConditionProvider(torch.nn.Module):
    """Prepend learned virtual-prompt embeddings to a text condition.

    Args:
        prompt_length: Number of virtual prompt tokens prepended per sequence.
        hidden_size: Embedding width of the virtual tokens and text tokens.
        num_transformer_submodules: Forwarded to ``PromptEncoderConfig``.
        num_attention_heads: Forwarded to ``PromptEncoderConfig``.
        num_layers: Forwarded to ``PromptEncoderConfig``.
        vocab_size: Size of the table used to embed integer token ids. The
            default matches the GPT-2 tokenizer vocabulary; pass the real size
            when a different tokenizer (or added tokens) is used.
    """

    def __init__(self, prompt_length, hidden_size, num_transformer_submodules, num_attention_heads, num_layers,
                 vocab_size=50257):
        super().__init__()
        self.config = PromptEncoderConfig(
            # "TEXT_GENERATION" is not a PEFT TaskType and is rejected by
            # config validation; CAUSAL_LM is the matching valid value.
            task_type="CAUSAL_LM",
            num_virtual_tokens=prompt_length,
            token_dim=hidden_size,
            encoder_hidden_size=hidden_size,
            encoder_num_layers=2,
            encoder_dropout=0.1,
            num_transformer_submodules=num_transformer_submodules,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers,
        )
        self.prompt_encoder = PromptEncoder(self.config)
        # Integer token ids have to be embedded before they can be
        # concatenated with the float prompt embeddings.
        self.token_embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.num_virtual_tokens = prompt_length

    def forward(self, tokens):
        """Return ``[virtual prompt embeddings ; token embeddings]``.

        Args:
            tokens: Either integer token ids of shape ``(batch, seq)``, which
                are embedded with ``self.token_embedding``, or already embedded
                tokens of shape ``(batch, seq, hidden_size)``, used as given.

        Returns:
            Tensor of shape ``(batch, prompt_length + seq, hidden_size)``.
        """
        batch_size = tokens.size(0)
        indices = torch.arange(self.num_virtual_tokens, device=tokens.device).unsqueeze(0).expand(batch_size, -1)
        prompt_embeds = self.prompt_encoder(indices)
        if prompt_embeds.dim() == 4:
            # Kept from the original: drop a leading singleton axis if present.
            prompt_embeds = prompt_embeds.squeeze(0)

        if tokens.dim() == 2 and not tokens.is_floating_point():
            token_embeds = self.token_embedding(tokens)
        else:
            token_embeds = tokens

        # Concatenate along the sequence axis; both parts are now 3-D floats.
        return torch.cat([prompt_embeds, token_embeds.to(prompt_embeds.dtype)], dim=1)


In [17]:
# Training loop
def train_model(model, tokenizer, dataloader, device, epochs, grad_acc_steps, lr, checkpoint_dir):
    """Jointly train ``model.lm`` and ``model.condition_provider``.

    Args:
        model: Object exposing ``lm`` and ``condition_provider`` modules.
        tokenizer: Callable returning an object with ``input_ids`` for a batch
            of strings.
        dataloader: Iterable yielding ``(audio, text)`` batches.
        device: Device the batches are moved to.
        epochs: Number of passes over ``dataloader``.
        grad_acc_steps: Number of batches accumulated per optimizer step.
        lr: AdamW learning rate.
        checkpoint_dir: Directory receiving one checkpoint per epoch; created
            if it does not exist.

    Raises:
        ValueError: If ``grad_acc_steps`` is smaller than 1.
    """
    if grad_acc_steps < 1:
        raise ValueError("grad_acc_steps must be >= 1")

    # torch.save does not create missing parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    optimizer = AdamW(
        list(model.lm.parameters()) + list(model.condition_provider.parameters()),
        lr=lr
    )
    loss_fn = torch.nn.MSELoss()

    model.lm.train()
    model.condition_provider.train()

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        optimizer.zero_grad()

        for i, (audio, text) in enumerate(dataloader):
            audio = audio.to(device)
            text_tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
            prompts = model.condition_provider(text_tokens)

            num_codebooks = model.lm.num_codebooks
            # NOTE(review): this casts raw waveform samples to integers and
            # feeds them to the LM as if they were codebook tokens. Presumably
            # the tokens should come from the model's audio codec and the loss
            # should be a classification loss - TODO confirm against the model API.
            audio = audio.unsqueeze(1).expand(-1, num_codebooks, -1).to(torch.long)

            try:
                outputs = model.lm(audio, prompts)
            except Exception as e:
                print(f"Error during training: {e}")
                raise

            # MSELoss cannot backpropagate against an integer target.
            loss = loss_fn(outputs, audio.to(outputs.dtype))
            total_loss += loss.item()
            num_batches += 1

            # Scale so that an accumulated group averages its batches instead
            # of summing them (which would multiply the effective LR).
            (loss / grad_acc_steps).backward()
            if (i + 1) % grad_acc_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        # Apply a trailing partial group rather than leaking its gradients
        # into the next epoch (or dropping them after the last one).
        if num_batches % grad_acc_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        # max(..., 1) keeps an empty dataloader from dividing by zero.
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / max(num_batches, 1)}")

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.lm.state_dict(),
            # The condition provider is optimized too, so it must be restorable.
            "condition_provider_state_dict": model.condition_provider.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, checkpoint_path)


In [18]:
# Music generation
def generate_music(model, tokenizer, text_condition, device, output_path):
    """Generate audio for ``text_condition`` and write it to ``output_path``.

    Args:
        model: Object exposing ``condition_provider`` and ``generate``; it may
            additionally expose ``lm``, ``eval`` and ``sample_rate``.
        tokenizer: Callable returning an object with ``input_ids``.
        text_condition: Text prompt describing the desired music.
        device: Device the token ids are moved to.
        output_path: Destination audio file path.
    """
    # train_model only toggles model.lm / model.condition_provider, so the
    # wrapper itself need not be an nn.Module: switch every part that can be
    # switched instead of assuming model.eval() exists.
    for part in (model, getattr(model, "lm", None), getattr(model, "condition_provider", None)):
        if hasattr(part, "eval"):
            part.eval()

    tokenized = tokenizer(text_condition, return_tensors="pt", padding=True, truncation=True)
    tokens = tokenized.input_ids.to(device)

    # Inference only: keep the conditioning pass out of the autograd graph too.
    with torch.no_grad():
        prompts = model.condition_provider(tokens)
        # NOTE(review): confirm that model.generate accepts this tensor; some
        # generation APIs expect text descriptions instead.
        generated_audio = model.generate(prompts)

    waveform = generated_audio.detach().cpu()
    if waveform.dim() == 3:
        # Assumes (batch, channels, samples) - TODO confirm. A single prompt is
        # passed in, so keep the first item and reorder to the
        # (frames, channels) layout that soundfile expects.
        waveform = waveform[0].transpose(0, 1)

    # Prefer the model's own rate when it reports one; 32000 was the
    # previously hard-coded value.
    sample_rate = getattr(model, "sample_rate", 32000)
    sf.write(output_path, waveform.numpy(), samplerate=sample_rate)
    print(f"Generated music saved to: {output_path}")


In [19]:
# Main script: configuration, model setup, training and generation
JSON_PATH = "data/Silent-Night.json"
CHECKPOINT_DIR = "checkpoints"
EPOCHS = 10
BATCH_SIZE = 16
GRAD_ACC_STEPS = 1
LR = 1e-4
PROMPT_LENGTH = 10
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from audiocraft.models import MusicGen

# Fix the RNG so shuffling and weight initialisation repeat across runs.
torch.manual_seed(SEED)


def pad_audio_collate(batch):
    """Collate ``(waveform, text)`` pairs, zero-padding waveforms to equal length.

    The default collate stacks tensors and therefore fails as soon as two
    clips in a batch differ in length.
    """
    waveforms, texts = zip(*batch)
    padded = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
    return padded, list(texts)


# Use the full checkpoint id; the bare "small" alias is deprecated.
model = MusicGen.get_pretrained("facebook/musicgen-small")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    # GPT-2 ships without a padding token; reuse EOS so batches can be padded.
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

prompt_provider = PEFTPConditionProvider(
    prompt_length=PROMPT_LENGTH,
    hidden_size=768,
    num_transformer_submodules=12,
    num_attention_heads=12,
    num_layers=12,
)
model.condition_provider = prompt_provider
model.lm = model.lm.to(DEVICE)
model.condition_provider = model.condition_provider.to(DEVICE)

dataset = JSONAudioDataset(JSON_PATH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_audio_collate)

# Checkpoints are written here once per epoch; make sure the folder exists.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

train_model(model, tokenizer, dataloader, DEVICE, EPOCHS, GRAD_ACC_STEPS, LR, CHECKPOINT_DIR)

generate_music(model, tokenizer, "A warm and cozy winter melody.", DEVICE, "generated_music.wav")


ModuleNotFoundError: No module named 'audiocraft'