In [22]:
import joblib
import pandas as pd
from sklearn import tree
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Load the dataset from music.csv.
music_data = pd.read_csv('music.csv')

# Separate the inputs (X: every column except 'genre') from the target label we
# want to predict (y: the 'genre' column).
X = music_data.drop(columns=['genre'])
y = music_data['genre']

# Hold out 20% of the rows for testing (roughly 75-80% train / 20-25% test is typical).
# random_state pins the shuffle so the split -- and therefore the score -- is reproducible
# across Restart & Run All.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit a decision tree on the training rows only, then predict the held-out rows.
# random_state makes the tree's tie-breaking between equally good splits deterministic.
model = DecisionTreeClassifier(random_state=42)
model.fit(X_train, y_train)
predictions = model.predict(X_test)

# Fraction of test rows whose predicted genre equals the true genre.
# NOTE(review): if music.csv is small, the 20% test split is only a few rows,
# so this score is noisy -- don't read too much into a perfect 1.0.
score = accuracy_score(y_test, predictions)
score

1.0

In [25]:
# Persist the trained model so later predictions can reuse it instead of retraining
# every time. Evaluation is done above, so this final model is fit on ALL rows
# (X and y from the first cell) rather than only the training split.
# joblib serializes the fitted estimator (not the data) to a binary file on disk.
model = DecisionTreeClassifier(random_state=42)  # fixed seed -> the saved model is reproducible
model.fit(X, y)

joblib.dump(model, 'music-recommender.joblib')

['music-recommender.joblib']

In [26]:
# Load the persisted model instead of retraining it on the same data.
# NOTE: joblib.load unpickles the file, which can execute code -- only load model files you trust.
model = joblib.load('music-recommender.joblib')

# Build the query with the same column names the model was fit on. Passing a bare
# nested list makes scikit-learn warn that "X does not have valid feature names".
new_listener = pd.DataFrame([[22, 1]], columns=model.feature_names_in_)
predictions = model.predict(new_listener)
predictions

array(['HipHop'], dtype=object)

In [27]:
# Visualise how the decision tree splits the data by exporting it in Graphviz .dot
# format (render with e.g. `dot -Tpng music_recommender.dot -o music_recommender.png`).
# Reuses X and y from the first cell rather than re-reading the CSV.
model = DecisionTreeClassifier(random_state=42)  # fixed seed -> same tree on every run
model.fit(X, y)

tree.export_graphviz(
    model,
    out_file='music_recommender.dot',
    # Derived from the data so node labels can never drift out of sync with the columns.
    feature_names=list(X.columns),
    # export_graphviz expects class names in model.classes_ order; use it directly.
    class_names=list(model.classes_),
    label='all',
    filled=True,
    rounded=True,
)