In [3]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        continue
        #print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

In [4]:
# 랜덤 시드 고정
import random
import librosa
import matplotlib.pyplot as plt
from tqdm import tqdm
import time

seed = 42

random.seed(seed)
np.random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

In [17]:
rootpath = '***'

In [18]:
# 학습할 음악들의 장르 정답 csv 파일
train_info_csv = pd.read_csv(rootpath + '/train_labels.csv')

# 제출에 사용할 csv 파일
submit = pd.read_csv(rootpath + '/submit.csv')
test_info_csv = submit['id']

In [6]:
# label
train_info_csv['genre'].unique()

array(['rock', 'country', 'metal', 'hiphop', 'pop', 'classical', 'disco',
       'reggae', 'blues', 'jazz'], dtype=object)

In [7]:
def extract_rhythm_features(file_path):
    # 반환할 feature list 선언
    feature = []
    y, sr = librosa.load(file_path, sr=22050)
    onset_envelope = librosa.onset.onset_strength(y=y, sr=sr)
    
    # autocorrelation tempogram
    tempogram = librosa.feature.tempogram(onset_envelope=onset_envelope, sr=sr)  # (384, 1293)
    #print(tempogram.shape)
    tempogram_feature = np.abs(np.mean(tempogram, axis=1))   # (384, 1293) -> (384, ), 복소수를 없애기 위해 절대값 처
    #print(tempogram_feature.shape)
    
    feature = tempogram_feature
    
    return feature

In [8]:
def extract_spectral_features(file_path):
    # 반환할 feature list 선언
    feature = []
    y, sr = librosa.load(file_path, sr=22050)

    spectrogram = np.abs(librosa.stft(y))
    power_spectrogram = spectrogram ** 2
    melspectrogram = librosa.feature.melspectrogram(S=power_spectrogram)
    melspectrogram_db = librosa.power_to_db(melspectrogram)
    
    # chromagram
    chromagram = librosa.feature.chroma_stft(S=power_spectrogram) # (12, 1293)
    chromagram_feature = np.mean(chromagram, axis=1)  # (12, 1293) -> (12, )
     
    # mfcc
    mfcc = librosa.feature.mfcc(S=melspectrogram_db)   # (20, 1293)
    mfcc_feature = np.mean(mfcc, axis=1) # (20, )
    
    feature = np.concatenate((chromagram_feature, mfcc_feature))

    return feature

In [9]:
def feature_loader(data_info, split=None, rootpath=None, domain=None):
    split = split.upper()
    info_dict = {}
    
    if split=='TRAIN':
        train_path = os.path.join(rootpath, 'train')
        file_list = data_info['id']
        label_list = data_info['genre']
        
        for file, label in zip(tqdm(file_list), label_list):
            # 손상된 wav파일 제외
            if file == 'train_412.wav':
                continue
                
            file_dict = {}     
            file_dict['label'] = label
            
            file_path = os.path.join(train_path, file)
            if domain == 'spectral':
                features = extract_spectral_features(file_path)
            elif domain == 'rhythm':
                features = extract_rhythm_features(file_path)
            else:
                raise Exception("Check domain")
                
            file_dict['features'] = features
            info_dict[file] = file_dict
        
        return info_dict
        
    elif split=='TEST':
        test_path = os.path.join(rootpath, 'test')
        file_list = data_info
        
        for file in tqdm(file_list):
            file_dict = {}
            file_path = os.path.join(test_path, file)
            
            if domain == 'spectral':
                features = extract_spectral_features(file_path)
            elif domain == 'rhythm':
                features = extract_rhythm_features(file_path)
            else:
                raise Exception("Check domain")
                
            file_dict['features'] = features
            info_dict[file] = file_dict
            
        return info_dict
    
    else:
        raise Exception("Check split")

In [10]:


# 'spectral' or 'rhythm'
domain = 'rhythm'

# 선택한 domain의 feature 추출, dictionary 반환 받기
train_data = feature_loader(train_info_csv, split='train', rootpath=rootpath, domain=domain)
test_data = feature_loader(test_info_csv, split='test', rootpath=rootpath, domain=domain) 

100%|██████████| 800/800 [03:09<00:00,  4.22it/s]
100%|██████████| 200/200 [00:43<00:00,  4.61it/s]


In [11]:
x_train = []
y_train = []

for key in train_data.keys():
    x_train.append(train_data[key]['features'])
    y_train.append(train_data[key]['label'])
    
x_train = np.asarray(x_train)
y_train = np.asarray(y_train) 

In [12]:
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
y_train = le.fit_transform(y_train)

In [13]:
x_test = []

for key in test_data.keys():
    x_test.append(test_data[key]['features'])
    #print(test_data[key]['features'].shape)
    
x_test = np.asarray(x_test)

In [14]:
from sklearn.ensemble import RandomForestClassifier
rfc = RandomForestClassifier(random_state=seed)
rfc.fit(x_train, y_train)
pred_rfc = rfc.predict(x_test)

In [15]:
submit['genre'] = le.inverse_transform(pred_rfc)
submit.to_csv(f'{domain}_feature_baseline.csv', index=False)

In [16]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score

train_pred = rfc.predict(x_train)

print(f"Accuracy score: {accuracy_score(y_train, train_pred)*100}")
print(confusion_matrix(y_train, train_pred))

Accuracy score: 99.87484355444305
[[80  0  0  0  0  0  0  0  0  0]
 [ 0 80  0  0  0  0  0  0  0  0]
 [ 0  0 80  0  0  0  0  0  0  0]
 [ 0  0  0 80  0  0  0  0  0  0]
 [ 0  0  0  0 80  0  0  0  0  0]
 [ 0  0  0  0  0 79  0  0  0  0]
 [ 0  0  0  0  0  0 79  0  0  1]
 [ 0  0  0  0  0  0  0 80  0  0]
 [ 0  0  0  0  0  0  0  0 80  0]
 [ 0  0  0  0  0  0  0  0  0 80]]
