In [None]:
# Download the GTZAN test archive (note: --no-check-certificate below disables TLS certificate verification)
!wget -O GTZAN_TEST.zip 'https://www.aliyundrive.com/drive/folder/6389d74426d2a0b9fba54e67bc7778c7f299e6d9/GTZAN_TEST.zip' --no-check-certificate

In [None]:
# unzip
!unzip GTZAN_TEST.zip

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset
import torchaudio
from torch.utils.data import DataLoader
import pandas as pd
import os
from itertools import product
from collections import namedtuple
from collections import OrderedDict
from IPython.display import display,clear_output
import time
import json
from torchsummary import summary
import matplotlib.pyplot as plt

torch.set_printoptions(linewidth=120)

In [None]:
# Add a numeric class id ("num_label") for every genre label in the test
# annotations. Labels are sorted alphabetically and numbered 0..N-1 in that order.
ANNOTATIONS_FILE = "./GTZAN_TEST/features_30_sec_test.csv"
dataframe = pd.read_csv(ANNOTATIONS_FILE)

# Read the labels from the named "label" column (not by position) so that the
# mapping is always built from the same column it is later applied to.
labels = set(dataframe["label"])
sorted_labels = sorted(labels)
mapping = {label: index for index, label in enumerate(sorted_labels)}

# Series.map yields the integer ids directly; assign() appends "num_label" as
# the last column of a new frame and leaves `dataframe` itself unmodified.
new_dataframe = dataframe.assign(num_label=dataframe["label"].map(mapping))
new_dataframe

In [None]:
# Persist the labelled annotations. to_csv writes the DataFrame index as the
# first CSV column by default, so every data column is shifted one position to
# the right when the file is read back with pd.read_csv. GTZANDataset below
# addresses columns by position (iloc), so do not pass index=False here without
# updating those lookups as well.
new_dataframe.to_csv("features_30_sec_test_final.csv")

In [None]:
# Data pre-processing class
class GTZANDataset(Dataset):
    """GTZAN audio clips served as fixed-length, single-channel feature tensors.

    Each item is loaded with torchaudio, moved to ``device``, resampled to
    ``target_sample_rate`` if needed, averaged down to one channel, cut or
    zero-padded to exactly ``num_samples`` samples and finally passed through
    ``transformation``.

    Args:
        annotations_file: CSV path (or buffer) accepted by ``pandas.read_csv``.
            Columns are addressed by position: column 1 holds the audio file
            name, the second-to-last column holds the sub-folder of
            ``audio_dir`` containing that file, and the last column holds the
            label returned with each sample.
        audio_dir: Root directory of the audio files.
        transformation: Module applied to the prepared waveform; it is moved
            to ``device`` on construction.
        target_sample_rate: Sample rate every waveform is resampled to.
        num_samples: Exact number of samples per waveform after cut/pad.
        device: Device the waveform and the transforms live on.
    """

    def __init__(self,
                 annotations_file,
                 audio_dir,
                 transformation,
                 target_sample_rate,
                 num_samples,
                 device):
        # Annotation table (file names, folders, labels)
        self.annotations = pd.read_csv(annotations_file)
        # Root directory of the audio files
        self.audio_dir = audio_dir
        # Device used for the waveform and the transforms
        self.device = device
        # Feature transform (e.g. mel spectrogram / MFCC), moved to the device
        self.transformation = transformation.to(self.device)
        # Sample rate every clip is converted to
        self.target_sample_rate = target_sample_rate
        # Fixed clip length in samples
        self.num_samples = num_samples
        # One Resample module per source sample rate, created on first use so
        # the resampling kernel is not rebuilt for every item.
        self._resamplers = {}

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(features, label, audio_path)`` for the row at ``index``."""
        audio_sample_path = self._get_audio_sample_path(index)
        label = self._get_audio_sample_label(index)
        # signal: waveform tensor, sr: sample rate of the file
        signal, sr = torchaudio.load(audio_sample_path)
        signal = signal.to(self.device)
        # Bring the clip to the target sample rate
        signal = self._resample_if_necessary(signal, sr)
        # Multi-channel -> single channel
        signal = self._mix_down_if_necessary(signal)
        # Force the clip to exactly num_samples samples
        signal = self._cut_if_necessary(signal)
        signal = self._right_pad_if_necessary(signal)
        # Convert the waveform into features
        signal = self.transformation(signal)
        return signal, label, audio_sample_path

    def _cut_if_necessary(self, signal):
        """Truncate ``signal`` along dim 1 to at most ``num_samples`` samples."""
        if signal.shape[1] > self.num_samples:
            signal = signal[:, :self.num_samples]
        return signal

    def _right_pad_if_necessary(self, signal):
        """Zero-pad ``signal`` on the right of its last dim up to ``num_samples``."""
        length_signal = signal.shape[1]
        if length_signal < self.num_samples:
            num_missing_samples = self.num_samples - length_signal
            # (left, right) padding for the last dimension
            last_dim_padding = (0, num_missing_samples)
            signal = torch.nn.functional.pad(signal, last_dim_padding)
        return signal

    def _resample_if_necessary(self, signal, sr):
        """Resample ``signal`` from ``sr`` to ``target_sample_rate`` when they differ."""
        if sr == self.target_sample_rate:
            return signal
        resampler = self._resamplers.get(sr)
        if resampler is None:
            # Build the resampler once per source rate and reuse it afterwards.
            resampler = torchaudio.transforms.Resample(
                sr, self.target_sample_rate).to(self.device)
            self._resamplers[sr] = resampler
        return resampler(signal)

    def _mix_down_if_necessary(self, signal):
        """Average all channels (dim 0) into one when there is more than one."""
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _get_audio_sample_path(self, index):
        """Build ``audio_dir/<folder>/<file name>`` for the row at ``index``."""
        # Second-to-last column: sub-folder; column 1: file name
        fold = f"{self.annotations.iloc[index, -2]}"
        path = os.path.join(self.audio_dir, fold, self.annotations.iloc[
            index, 1])
        return path

    def _get_audio_sample_label(self, index):
        """Return the label stored in the last CSV column for the row at ``index``."""
        return self.annotations.iloc[index, -1]
    

if __name__ == "__main__":
    # Demo: build the dataset with an MFCC front end and plot one example.
    # NOTE(review): these paths differ from the GTZAN_TEST files prepared in
    # the cells above -- confirm that they exist before running.
    ANNOTATIONS_FILE = "./features_30_sec_final.csv"
    AUDIO_DIR = "./GTZAN/genres_original"
    SAMPLE_RATE = 22050
    NUM_SAMPLES = 5 * SAMPLE_RATE  # 110250 samples = 5 seconds at 22050 Hz
    plot = True

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")

    # Mel spectrogram transform (1024-point FFT, hop of 512, 64 mel bands).
    # It is constructed here but is not the transform handed to the dataset.
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE, n_fft=1024, hop_length=512, n_mels=64
    )

    # MFCC transform with 40 coefficients; this is the one the dataset uses.
    mfcc = torchaudio.transforms.MFCC(sample_rate=SAMPLE_RATE, n_mfcc=40, log_mels=True)

    gtzan = GTZANDataset(
        annotations_file=ANNOTATIONS_FILE,
        audio_dir=AUDIO_DIR,
        transformation=mfcc,
        target_sample_rate=SAMPLE_RATE,
        num_samples=NUM_SAMPLES,
        device=device,
    )

    print(f"There are {len(gtzan)} samples in the dataset")

    if plot:
        # Index 666 requires at least 667 rows in the annotations file.
        signal, label, path = gtzan[666]
        print(f'path:{path}')
        # Move the features back to the CPU before handing them to matplotlib.
        signal = signal.cpu()
        print(signal.shape)

        # Show the first channel's feature matrix as an image.
        plt.figure(figsize=(16, 8), facecolor="white")
        plt.imshow(signal[0, :, :], origin='lower')
        plt.autoscale(False)
        plt.xlabel("Time")
        plt.ylabel("Frequency")
        plt.colorbar()
        plt.axis('auto')
        plt.show()


In [None]:
content