In [1]:
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

Read data

In [2]:
df = pd.read_csv("../data/music_genre.csv")
df.head(5)

Unnamed: 0,instance_id,artist_name,track_name,popularity,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,mode,speechiness,tempo,obtained_date,valence,music_genre
0,32894.0,Röyksopp,Röyksopp's Night Out,27.0,0.00468,0.652,-1.0,0.941,0.792,A#,0.115,-5.201,Minor,0.0748,100.889,4-Apr,0.759,Electronic
1,46652.0,Thievery Corporation,The Shining Path,31.0,0.0127,0.622,218293.0,0.89,0.95,D,0.124,-7.043,Minor,0.03,115.002,4-Apr,0.531,Electronic
2,30097.0,Dillon Francis,Hurricane,28.0,0.00306,0.62,215613.0,0.755,0.0118,G#,0.534,-4.617,Major,0.0345,127.994,4-Apr,0.333,Electronic
3,62177.0,Dubloadz,Nitro,34.0,0.0254,0.774,166875.0,0.7,0.00253,C#,0.157,-4.498,Major,0.239,128.014,4-Apr,0.27,Electronic
4,24907.0,What So Not,Divide & Conquer,32.0,0.00465,0.638,222369.0,0.587,0.909,F#,0.157,-6.266,Major,0.0413,145.036,4-Apr,0.323,Electronic


Check distribution of labels

In [3]:
df['music_genre'].value_counts()

Rock           5000
Classical      5000
Hip-Hop        5000
Rap            5000
Electronic     5000
Anime          5000
Blues          5000
Country        5000
Alternative    5000
Jazz           5000
Name: music_genre, dtype: int64

Split data into train and test sets

In [4]:
X = df.drop('music_genre', axis=1)
y = df['music_genre']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

display(X_train.shape)
display(X_test.shape)
display(y_train.shape)
display(y_test.shape)

(40004, 17)

(10001, 17)

(40004,)

(10001,)

Save data into CSV files, ignoring index

In [5]:
df_train = pd.concat((X_train, y_train), axis=1)
df_test = pd.concat((X_test, y_test), axis=1)

df_train.to_csv("../data/music_genre_train.csv", index=False)
df_test.to_csv("../data/music_genre_test.csv", index=False)

Import data again and check if everything is fine

In [6]:
df_train_1 = pd.read_csv("../data/music_genre_train.csv")
df_train_1.head(5)

Unnamed: 0,instance_id,artist_name,track_name,popularity,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,mode,speechiness,tempo,obtained_date,valence,music_genre
0,35466.0,Linkin Park,Wth>You (feat. Aceyalone),38.0,0.00696,0.58,252307.0,0.982,5e-06,E,0.165,-3.428,Minor,0.0729,98.189,4-Apr,0.51,Alternative
1,21927.0,Matthew Dear,Bad Ones (feat. Tegan and Sara),52.0,0.00693,0.769,276983.0,0.511,0.0885,B,0.223,-8.357,Major,0.0402,104.003,4-Apr,0.0399,Alternative
2,91335.0,La-33,La Pantera Mambo,47.0,0.474,0.782,-1.0,0.791,0.0012,C,0.0555,-6.321,Major,0.0383,99.992,4-Apr,0.819,Jazz
3,91513.0,Sheena Ringo,あおぞら,35.0,0.643,0.701,255800.0,0.762,0.612,B,0.432,-6.117,Major,0.042,130.105,4-Apr,0.641,Anime
4,80060.0,MGMT,"Of Moons, Birds & Monsters",52.0,0.0218,0.598,286720.0,0.899,0.0137,C#,0.0814,-3.861,Minor,0.0424,124.061,4-Apr,0.403,Rock


In [7]:
df_test_1 = pd.read_csv("../data/music_genre_test.csv")
df_test_1.head(5)

Unnamed: 0,instance_id,artist_name,track_name,popularity,acousticness,danceability,duration_ms,energy,instrumentalness,key,liveness,loudness,mode,speechiness,tempo,obtained_date,valence,music_genre
0,87766.0,XXXTENTACION,Boost!,58.0,0.0642,0.723,77601.0,0.328,0.0,C#,0.123,-11.121,Major,0.477,124.087,4-Apr,0.382,Rap
1,29741.0,hachi,バウムクーヘン - ORIGINAL,17.0,0.471,0.649,213200.0,0.626,0.00021,G,0.0882,-5.551,Major,0.0516,125.067,4-Apr,0.695,Anime
2,84644.0,Sik World,Mental Issues,58.0,0.066,0.907,209907.0,0.476,0.0,C#,0.0781,-10.867,Major,0.344,135.71200000000002,4-Apr,0.341,Rap
3,40883.0,Young Nudy,Friday,48.0,0.17,0.796,178987.0,0.484,0.0,A#,0.151,-7.131,Minor,0.149,137.901,4-Apr,0.158,Hip-Hop
4,66562.0,Chris Smither,Hold On I,28.0,0.858,0.529,243645.0,0.496,0.0163,E,0.69,-9.078,Major,0.0285,102.59,4-Apr,0.504,Blues


In [8]:
display(df_train.shape)
display(df_test.shape)

(40004, 18)

(10001, 18)

Check distribution of labels

In [9]:
df_train_1['music_genre'].value_counts()

Country        4040
Jazz           4034
Classical      4018
Alternative    4017
Electronic     3991
Rock           3991
Hip-Hop        3991
Blues          3983
Rap            3968
Anime          3967
Name: music_genre, dtype: int64

In [10]:
df_test_1['music_genre'].value_counts()

Anime          1033
Rap            1032
Blues          1017
Hip-Hop        1009
Electronic     1009
Rock           1009
Alternative     983
Classical       982
Jazz            966
Country         960
Name: music_genre, dtype: int64