In [None]:
'''
#########################################################################################
Model-training code using 0.1-second segments. Instead of padding the audio,
we resample it.
#########################################################################################
'''

In [1]:
import os
import torchaudio
import torch
from torch.utils.data import Dataset

class AudioDatasetResample(Dataset):
    """Paired lossy/lossless audio segments, resampled to a common sample rate.

    Every regular file in ``lossless_dir`` is paired with the file at the same
    position in ``lossy_dir`` after both listings are sorted by path, so the two
    directories must list their files in the same order. Each pair is loaded,
    resampled to ``target_sample_rate`` where needed, and cut into
    non-overlapping segments of ``segment_duration`` seconds; a trailing partial
    segment is dropped. All segment pairs are held in memory.

    Args:
        lossless_dir: Directory containing the lossless files.
        lossy_dir: Directory containing the lossy files.
        segment_duration: Segment length in seconds.
        target_sample_rate: Sample rate (Hz) both waveforms are resampled to.
            NOTE(review): the default is 44000, not the CD rate 44100, so
            44.1 kHz sources are always resampled. It also fixes the segment
            length in samples (4400 by default); confirm before changing.

    Raises:
        AssertionError: If the directories hold different numbers of files.
        ValueError: If ``segment_duration * target_sample_rate`` rounds to
            fewer than one sample.
    """

    def __init__(self, lossless_dir, lossy_dir, segment_duration=0.1, target_sample_rate=44000):
        self.lossless_files = self._list_files(lossless_dir)
        self.lossy_files = self._list_files(lossy_dir)

        assert len(self.lossless_files) == len(self.lossy_files), "Mismatch in number of lossless and lossy files!"

        self.segment_duration = segment_duration
        self.target_sample_rate = target_sample_rate
        # Fail fast on a segment shorter than one sample: a zero length would
        # make range() raise mid-load, and a negative one would yield nothing.
        if self._segment_size() < 1:
            raise ValueError(
                f"segment_duration={segment_duration!r} at target_sample_rate="
                f"{target_sample_rate!r} gives fewer than one sample per segment."
            )
        self.data = []  # (lossy_segment, lossless_segment) pairs, all held in memory

        # Process and add all files
        self.process_and_add()

    @staticmethod
    def _list_files(directory):
        """Return the sorted full paths of the regular files directly inside *directory*."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    def _segment_size(self):
        """Return the segment length in samples at the target sample rate."""
        # round(), not plain int(): float products such as 0.57 * 100 evaluate
        # to 56.99999999999999 and would otherwise truncate one sample short.
        return int(round(self.target_sample_rate * self.segment_duration))

    def _resample(self, waveform, orig_sr):
        """Return *waveform* converted from *orig_sr* to the target rate (as-is if equal)."""
        if orig_sr == self.target_sample_rate:
            return waveform
        resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=self.target_sample_rate)
        return resampler(waveform)

    @staticmethod
    def _split_segments(waveform, segment_size):
        """Cut *waveform* into consecutive full-length segments along axis 1.

        The waveform is indexed as (channels, samples), the layout returned by
        ``torchaudio.load``. Samples past the last full segment are dropped.
        Each segment is a slice (view) of *waveform*.
        """
        full_length = (waveform.shape[1] // segment_size) * segment_size
        return [waveform[:, start:start + segment_size] for start in range(0, full_length, segment_size)]

    def process_and_add(self):
        """Process every file pair in order and append its segment pairs to ``self.data``."""
        for idx, (lossless_path, lossy_path) in enumerate(zip(self.lossless_files, self.lossy_files)):
            self.data.extend(self.process_pair(lossless_path, lossy_path))
            if (idx + 1) % 10 == 0:
                print(f"Processed {idx + 1}/{len(self.lossless_files)} songs...")

        print(f"Dataset created with {len(self.data)} valid segment pairs.")

    def process_pair(self, lossless_path, lossy_path):
        """Load, resample and segment one lossless/lossy file pair.

        Returns:
            A list of ``(lossy_segment, lossless_segment)`` tuples, or an empty
            list when the two files yield different numbers of full segments.
        """
        lossless_waveform, lossless_sr = torchaudio.load(lossless_path)
        lossy_waveform, lossy_sr = torchaudio.load(lossy_path)

        lossless_waveform = self._resample(lossless_waveform, lossless_sr)
        lossy_waveform = self._resample(lossy_waveform, lossy_sr)

        segment_size = self._segment_size()
        lossless_segments = self._split_segments(lossless_waveform, segment_size)
        lossy_segments = self._split_segments(lossy_waveform, segment_size)

        # Segments are paired by index, so a count mismatch means this song
        # cannot be aligned segment-for-segment; skip it entirely.
        if len(lossless_segments) != len(lossy_segments):
            print(f"Skipping {lossless_path} and {lossy_path} due to unequal segments.")
            return []

        # NOTE(review): channel counts are not compared, so a mono/stereo pair
        # would produce segments of different shapes; confirm inputs match.
        return list(zip(lossy_segments, lossless_segments))

    def __len__(self):
        """Return the number of stored segment pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(lossy_segment, lossless_segment)`` pair at *idx*."""
        return self.data[idx]