In [1]:
import torch
from torch.utils.data import Dataset
import torchaudio
import json
import os
from transformers import T5Tokenizer, T5EncoderModel

In [None]:
# ========== Imports ==========
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
import json
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import T5Tokenizer, T5EncoderModel

# ========== Dataset ==========
class MusicBenchDataset(Dataset):
    def __init__(self, json_path, audio_base_path, tokenizer, max_length=512, sample_rate=16000, n_mels=80, frames=400, hop_size=256):
        self.data = []
        with open(json_path, 'r') as f:
            for line in f:
                self.data.append(json.loads(line))  # JSON lines format

        self.audio_base_path = audio_base_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.frames = frames
        self.hop_size = hop_size
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate, n_mels=self.n_mels, hop_length=self.hop_size
        )

    def fix_mel_length(self, mel):
        n_mels, frames = mel.shape
        if frames > self.frames:
            mel = mel[:, :self.frames]
        elif frames < self.frames:
            pad = torch.zeros((n_mels, self.frames - frames))
            mel = torch.cat([mel, pad], dim=1)
        return mel

    def fix_wav_length(self, wav):
        target_length = self.frames * self.hop_size
        if wav.shape[0] > target_length:
            wav = wav[:target_length]
        elif wav.shape[0] < target_length:
            pad = torch.zeros(target_length - wav.shape[0])
            wav = torch.cat([wav, pad])
        return wav

    def preprocess_audio(self, filepath):
        waveform, sr = torchaudio.load(filepath)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        mel = self.mel_transform(waveform).squeeze(0)  # (n_mels, time)
        mel = self.fix_mel_length(mel)
        waveform = waveform.squeeze(0)  # (samples)
        waveform = self.fix_wav_length(waveform)

        return mel, waveform

    def preprocess_text(self, item):
        text = f"{item['main_caption']} {item['alt_caption']} {item['prompt_bpm']} {item['prompt_key']} {item['prompt_bt']} {item['prompt_ch']}"
        return text

    def __getitem__(self, idx):
        item = self.data[idx]
        audio_path = os.path.join(self.audio_base_path, item['location'])

        mel, wav = self.preprocess_audio(audio_path)

        text = self.preprocess_text(item)
        text_inputs = self.tokenizer(text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt')

        bpm = item['bpm'] or 120.0  # Default if bpm is None
        bpm = float(bpm) / 300.0  # normalize BPM

        return {
            'input_ids': text_inputs['input_ids'].squeeze(0),
            'attention_mask': text_inputs['attention_mask'].squeeze(0),
            'bpm': torch.tensor(bpm, dtype=torch.float32),
            'mel': mel,
            'wav': wav
        }

    def __len__(self):
        return len(self.data)

# ========== Text-to-Mel Model ==========
class TextToMelModel(nn.Module):
    def __init__(self, text_embed_dim, mel_bins=80, frames=400):
        super().__init__()
        self.frames = frames
        self.fc = nn.Sequential(
            nn.Linear(text_embed_dim + 1, 1024),
            nn.ReLU(),
            nn.Linear(1024, mel_bins * frames)
        )
        self.mel_bins = mel_bins

    def forward(self, text_embeddings, bpm):
        x = torch.cat([text_embeddings, bpm.unsqueeze(1)], dim=1)
        output = self.fc(x)
        output = output.view(-1, self.mel_bins, self.frames)
        return output

# ========== SimpleMel2Wav Vocoder ==========
class SimpleMel2Wav(nn.Module):
    def __init__(self, n_mels=80, upsample_scales=[8, 8, 4, 1]):
        super().__init__()
        self.initial = nn.Conv1d(n_mels, 512, kernel_size=7, padding=3)
        layers = []
        in_channels = 512
        for scale in upsample_scales:
            layers.append(nn.ConvTranspose1d(
                in_channels, in_channels // 2, kernel_size=scale * 2, stride=scale, padding=scale // 2))
            layers.append(nn.LeakyReLU(0.2))
            in_channels = in_channels // 2
        self.upsample = nn.Sequential(*layers)
        self.final = nn.Conv1d(in_channels, 1, kernel_size=7, padding=3)
        self.tanh = nn.Tanh()

    def forward(self, mel):
        x = self.initial(mel)
        x = self.upsample(x)
        x = self.final(x)
        x = self.tanh(x)
        x = x.squeeze(1)  # [batch, time]

        target_length = mel.shape[-1] * 256  # frames × hop_size
        if x.shape[1] > target_length:
            x = x[:, :target_length]
        elif x.shape[1] < target_length:
            pad = torch.zeros((x.shape[0], target_length - x.shape[1]), device=x.device)
            x = torch.cat([x, pad], dim=1)
        return x


# ========== Helper: Plot Mel Spectrogram ==========
def plot_mel(mel, title="Mel Spectrogram"):
    plt.figure(figsize=(10, 4))
    plt.imshow(mel.cpu().detach().numpy(), aspect='auto', origin='lower')
    plt.title(title)
    plt.xlabel("Frames")
    plt.ylabel("Mel Bins")
    plt.colorbar()
    plt.show()

# ========== Training ==========
def train(mel_model, vocoder_model, text_encoder, dataloader, optimizer, mel_loss_fn, wav_loss_fn, device):
    mel_model.train()
    vocoder_model.train()
    text_encoder.eval()

    total_loss = 0
    pbar = tqdm(dataloader, desc="Training", leave=False)
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        bpm = batch['bpm'].to(device)
        mel_target = batch['mel'].to(device)
        wav_target = batch['wav'].to(device)

        with torch.no_grad():
            text_outputs = text_encoder(input_ids=input_ids, attention_mask=attention_mask)
            text_embeds = text_outputs.last_hidden_state.mean(dim=1)

        mel_pred = mel_model(text_embeds, bpm)
        wav_pred = vocoder_model(mel_pred)

        mel_loss = mel_loss_fn(mel_pred, mel_target)
        wav_loss = wav_loss_fn(wav_pred, wav_target)

        loss = mel_loss + wav_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(dataloader)
    return avg_loss

# ========== Evaluation ==========
def evaluate(mel_model, vocoder_model, text_encoder, dataloader, mel_loss_fn, wav_loss_fn, device):
    mel_model.eval()
    vocoder_model.eval()
    text_encoder.eval()

    total_loss = 0
    pbar = tqdm(dataloader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            bpm = batch['bpm'].to(device)
            mel_target = batch['mel'].to(device)
            wav_target = batch['wav'].to(device)

            text_outputs = text_encoder(input_ids=input_ids, attention_mask=attention_mask)
            text_embeds = text_outputs.last_hidden_state.mean(dim=1)

            mel_pred = mel_model(text_embeds, bpm)
            wav_pred = vocoder_model(mel_pred)

            mel_loss = mel_loss_fn(mel_pred, mel_target)
            wav_loss = wav_loss_fn(wav_pred, wav_target)

            loss = mel_loss + wav_loss
            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(dataloader)
    return avg_loss

# ========== Inference ==========
def generate(mel_model, vocoder_model, text_encoder, tokenizer, text_prompt, bpm_value, device):
    mel_model.eval()
    vocoder_model.eval()
    text_encoder.eval()

    inputs = tokenizer(text_prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    bpm_tensor = torch.tensor([bpm_value / 300.0], dtype=torch.float32).to(device)

    with torch.no_grad():
        text_outputs = text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_embeds = text_outputs.last_hidden_state.mean(dim=1)

        mel_pred = mel_model(text_embeds, bpm_tensor)
        plot_mel(mel_pred.squeeze(0), title="Generated Mel Spectrogram")

        wav = vocoder_model(mel_pred)
        return wav

# ========== Main ==========
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    text_encoder = T5EncoderModel.from_pretrained('t5-small').to(device)

    json_path = 'MusicBench_train.json'
    audio_base_path = 'datashare'

    full_dataset = MusicBenchDataset(
        json_path=json_path,
        audio_base_path=audio_base_path,
        tokenizer=tokenizer
    )

    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    mel_model = TextToMelModel(text_embed_dim=512, mel_bins=80, frames=400).to(device)
    vocoder_model = SimpleMel2Wav(n_mels=80).to(device)

    optimizer = torch.optim.Adam(list(mel_model.parameters()) + list(vocoder_model.parameters()), lr=1e-4)
    mel_loss_fn = nn.MSELoss()
    wav_loss_fn = nn.L1Loss()

    epochs = 4
    for epoch in range(epochs):
        print(f"\n=== Epoch {epoch+1}/{epochs} ===")
        train_loss = train(mel_model, vocoder_model, text_encoder, train_loader, optimizer, mel_loss_fn, wav_loss_fn, device)
        val_loss = evaluate(mel_model, vocoder_model, text_encoder, val_loader, mel_loss_fn, wav_loss_fn, device)
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Inference Example
    example_text = "A calm piano melody with a soft beat and slow tempo in C major."
    bpm_value = 90.0
    wav = generate(mel_model, vocoder_model, text_encoder, tokenizer, example_text, bpm_value, device)

    if wav is not None:
        wav = wav.squeeze(0)  # Remove batch dimension if it's [1, time]
        wav = wav.unsqueeze(0)  # Now [1, time] (mono)
        torchaudio.save('generated_output.wav', wav.cpu(), 16000)


if __name__ == "__main__":
    main()



=== Epoch 1/4 ===


Training:   0%|          | 29/5937 [00:03<12:49,  7.67it/s, loss=9.25e+3]