In [1]:
import os
import json
import torch
import librosa
import numpy as np
from torch.utils.data import Dataset

In [None]:
# Hyperparameters: these are the default argument values of
# MusicSegmentationDataset.__init__ and control audio loading and feature extraction.
SAMPLING_RATE = 22050  # forwarded to librosa.load as `sr` (sample rate the audio is loaded at)
NO_MFCC = 13  # forwarded to librosa.feature.mfcc as `n_mfcc` (number of MFCC coefficients)
HOP_LENGTH = 512  # forwarded as `hop_length` to the MFCC, chroma and mel-spectrogram extractors
DURATION = 30  # forwarded to librosa.load as `duration` (only this much audio is read per file)

In [2]:
class MusicSegmentationDataset(Dataset):
    """Dataset pairing per-song audio features with section annotations.

    Expected directory layout::

        root_dir/
            songs/      audio files, one per song
            metadata/   one JSON annotation file per song

    Every annotation file must provide the keys ``pallavi``, ``anupallavi``,
    ``charanam`` and ``bgm``; each value is converted with ``torch.tensor``.
    (What those values encode is not visible in this file -- presumably
    section boundaries; confirm against the annotation tooling.)

    Args:
        root_dir: Directory containing the ``songs`` and ``metadata`` folders.
        sr: Sample rate passed to ``librosa.load``.
        n_mfcc: Number of MFCC coefficients to extract.
        hop_length: Hop length shared by all three feature extractors.
        duration: Maximum amount of audio loaded per song (``librosa.load``'s
            ``duration`` argument).
    """

    # Keys read from every metadata JSON file, in the order they are returned.
    _SECTION_KEYS = ('pallavi', 'anupallavi', 'charanam', 'bgm')

    def __init__(self, root_dir, sr=SAMPLING_RATE, n_mfcc=NO_MFCC, hop_length=HOP_LENGTH, duration=DURATION):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sorting makes an
        # index refer to the same song on every run and on every machine.
        self.song_files = sorted(os.listdir(os.path.join(self.root_dir, 'songs')))
        self.metadata_dir = os.path.join(self.root_dir, 'metadata')
        self.sr = sr
        self.n_mfcc = n_mfcc
        self.hop_length = hop_length
        self.duration = duration

    def _load_ground_truth(self, song_name):
        """Load the annotation tensors for ``song_name``.

        ``song_name`` is an entry of ``self.song_files`` and therefore still
        carries its audio extension. The annotation is looked up first as
        ``<song_name>.json`` (e.g. ``track.mp3.json``) and, if that is absent,
        as ``<stem>.json`` (e.g. ``track.json``).

        Returns:
            A dict mapping each of ``pallavi``, ``anupallavi``, ``charanam``
            and ``bgm`` to a tensor, or ``None`` when no annotation file
            exists for the song.

        Raises:
            KeyError: If the annotation file lacks one of the expected keys.
        """
        candidates = [f"{song_name}.json"]
        stem = os.path.splitext(song_name)[0]
        if stem != song_name:
            candidates.append(f"{stem}.json")

        for candidate in candidates:
            metadata_file = os.path.join(self.metadata_dir, candidate)
            if not os.path.isfile(metadata_file):
                continue
            # JSON is UTF-8 by specification; do not rely on the platform default.
            with open(metadata_file, 'r', encoding='utf-8') as json_file:
                metadata = json.load(json_file)
            return {key: torch.tensor(metadata[key]) for key in self._SECTION_KEYS}

        return None

    def _extract_features(self, music_file):
        """Compute MFCC, chroma and mel-spectrogram features for one audio file.

        The audio is loaded at ``self.sr`` and limited to ``self.duration``.
        Each librosa feature matrix is transposed before conversion, so the
        returned tensors are laid out as ``(frames, feature_bins)``.

        Returns:
            Tuple ``(mfcc, chroma, mel_spec)`` of ``float32`` tensors. ``mfcc``
            has ``self.n_mfcc`` columns; ``chroma`` and ``mel_spec`` use
            librosa's default bin counts.
        """
        y, sr = librosa.load(music_file, sr=self.sr, duration=self.duration)

        # All three extractors share the hop length so their frame axes line up.
        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc, hop_length=self.hop_length)
        chroma = librosa.feature.chroma_stft(y=y, sr=sr, hop_length=self.hop_length)
        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, hop_length=self.hop_length)

        return (
            torch.tensor(mfcc.T, dtype=torch.float32),
            torch.tensor(chroma.T, dtype=torch.float32),
            torch.tensor(mel_spec.T, dtype=torch.float32),
        )

    def __len__(self):
        """Return the number of entries found in the ``songs`` directory."""
        return len(self.song_files)

    def __getitem__(self, idx):
        """Return the features and annotations of the song at position ``idx``.

        Returns:
            A dict with keys ``mfcc``, ``chroma``, ``mel_spec`` and
            ``ground_truth``, or ``None`` when the song has no annotation
            file. NOTE: PyTorch's default collate function cannot batch
            ``None`` items; callers using a ``DataLoader`` need a
            ``collate_fn`` that drops them.
        """
        song_name = self.song_files[idx]
        ground_truth = self._load_ground_truth(song_name)
        if ground_truth is None:
            # No annotation: skip the expensive audio decoding entirely.
            return None

        music_file = os.path.join(self.root_dir, 'songs', song_name)
        mfcc, chroma, mel_spec = self._extract_features(music_file)
        return {
            'mfcc': mfcc,
            'chroma': chroma,
            'mel_spec': mel_spec,
            'ground_truth': ground_truth
        }
