# Data Preparation Using PyTorch Dataset and DataLoader

In [1]:
import os
import pandas as pd
import numpy as np
import librosa as lr
import matplotlib.pyplot as plt
import pickle

import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio



In [9]:
# Report the installed torch and torchaudio versions, one per line.
print(torch.__version__, torchaudio.__version__, sep="\n")

2.4.0+cu121
2.4.0+cu121


In [2]:
# Relative paths, resolved against the notebook's current working directory.
DATASET_DIR, DATAINFO_DIR = (
    "../../extracted_audio",  # audio files, one sub-directory per split
    "../../dataset_info",  # presumably the split CSVs with labels -- TODO confirm
)

In [5]:
def create_melspectogram(files, split_type):
    """Compute a float16 mel spectrogram for each audio file of a split.

    Each file is read from ``DATASET_DIR/<split_type>/<file>`` with
    ``torchaudio.load`` and passed through ``torchaudio.transforms.MelSpectrogram``
    (n_fft=1024, n_mels=128, hop_length=512, slaney norm) at the file's own
    sample rate.

    Args:
        files: File names relative to ``DATASET_DIR/<split_type>``.
        split_type: Sub-directory of ``DATASET_DIR`` holding the files
            (e.g. ``"train"``).

    Returns:
        A list with one float16 mel spectrogram tensor per entry of ``files``,
        in the same order.
    """
    # Constructing a MelSpectrogram builds its mel filterbank, so build one
    # transform per distinct sample rate and reuse it instead of rebuilding
    # it for every file.
    transforms_by_sr = {}
    melspectograms = []
    for file in files:
        audio_path = os.path.join(DATASET_DIR, split_type, file)
        wfm, sr = torchaudio.load(audio_path)
        transform = transforms_by_sr.get(sr)
        if transform is None:
            transform = torchaudio.transforms.MelSpectrogram(
                n_fft=1024, n_mels=128, hop_length=512, sample_rate=sr, norm="slaney"
            )
            transforms_by_sr[sr] = transform
        # float16 halves memory. NOTE(review): very small power values can
        # underflow to 0 in float16 -- check before taking a log downstream.
        melspectograms.append(transform(wfm).to(torch.float16))
    return melspectograms

In [6]:
# List the training audio files in sorted (deterministic) order, then compute
# one mel spectrogram per file.
train_files = os.listdir(os.path.join(DATASET_DIR, "train"))
train_files.sort()

train_spectograms = create_melspectogram(train_files, split_type="train")

In [None]:
# load spectrograms
with open("train_melspectograms.pkl", "rb") as f:
    train_spectograms = pickle.load(f)

In [22]:
def get_max_shape(all_spectograms):
    """Return the largest size along dimension 2 across ``all_spectograms``.

    Args:
        all_spectograms: Iterable of array-like objects exposing ``.shape``
            with at least three dimensions.

    Returns:
        The maximum of ``spect.shape[2]`` over all elements.

    Raises:
        ValueError: If ``all_spectograms`` is empty (raised by ``max``).
    """
    widths = (spect.shape[2] for spect in all_spectograms)
    return max(widths)
    
def pad_spectograms(all_spectograms):
    """Right-pad every spectrogram with zeros to the longest one, in place.

    The target length comes from ``get_max_shape`` (dimension 2). The padding
    is applied by ``torch.nn.functional.pad`` to the LAST dimension, so the
    two agree only for 3-D tensors -- presumably (channel, n_mels, time);
    TODO confirm.

    Args:
        all_spectograms: Mutable, indexable sequence of tensors. Each element
            is replaced in place by its padded version.

    Returns:
        The same ``all_spectograms`` object, now padded. An empty input is
        returned unchanged.
    """
    # get_max_shape raises on an empty sequence; there is nothing to pad.
    # len() rather than truthiness so a batched tensor input still works.
    if len(all_spectograms) == 0:
        return all_spectograms

    max_len = get_max_shape(all_spectograms)
    for i, spectogram in enumerate(all_spectograms):
        pad_len = max_len - spectogram.shape[2]
        # (0, pad_len): no padding on the left, pad_len zeros on the right.
        all_spectograms[i] = torch.nn.functional.pad(spectogram, (0, pad_len))

    return all_spectograms

In [None]:
# pad_spectograms pads the list in place and returns that same list, so this
# name is an alias of train_spectograms, not a copy.
padded_train_spectograms = pad_spectograms(all_spectograms=train_spectograms)

In [15]:
# Persist the spectrogram list; it is already padded if pad_spectograms ran,
# because that function mutates the list in place.
with open("train_melspectograms.pkl", mode="wb") as pkl_out:
    pickle.dump(train_spectograms, pkl_out)

In [6]:
class AudioDataset(Dataset):
    """Map-style dataset pairing precomputed spectrograms with ``PHQ_Binary`` labels.

    Args:
        dataset_info_dir: Directory holding the split CSV files.
        spectograms_file: Path passed to ``np.load``; sample ``i`` is
            ``loaded[i]``. NOTE(review): ``np.load`` is called without
            ``allow_pickle``, so pickle files are rejected -- expects a
            ``.npy`` array; confirm.
        split_type: ``"train"`` reads ``train_split_augmented.csv``; any other
            value ``v`` reads ``f"{v}_split.csv"``.

    Raises:
        ValueError: If the number of spectrograms differs from the number of
            label rows, which would silently misalign samples and labels.
    """

    def __init__(self, dataset_info_dir, spectograms_file, split_type):
        self.spectograms = np.load(spectograms_file)

        if split_type == "train":
            csv_name = "train_split_augmented.csv"
        else:
            csv_name = f"{split_type}_split.csv"
        csv_path = os.path.join(dataset_info_dir, csv_name)
        df = pd.read_csv(csv_path)
        self.labels = df["PHQ_Binary"]

        if len(self.labels) != len(self.spectograms):
            raise ValueError(
                f"{len(self.spectograms)} spectrograms in {spectograms_file!r} "
                f"but {len(self.labels)} labels in {csv_path!r}"
            )

    def __len__(self):
        return len(self.spectograms)

    def __getitem__(self, idx):
        spectogram = self.spectograms[idx]
        # Positional lookup so labels line up with spectograms regardless of
        # the Series index, and negative indices behave like the array side.
        label = self.labels.iloc[idx]

        return torch.tensor(spectogram), torch.tensor(label)