In [10]:
# Preparing the Data
# Learning and Predicting
# Calculating the Accuracy

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score


# Fixed seed so the train/test split and the fitted tree -- and therefore the
# accuracy shown below -- are the same on every Restart & Run All.
RANDOM_STATE = 42

music_data = pd.read_csv('music.csv')  # load the raw table from the CSV file
X = music_data.drop(columns=['genre'])  # input features: every column except 'genre'
y = music_data['genre']  # target labels

# Hold out 20% of the rows for evaluation; the rest is used for training.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=RANDOM_STATE
)

# The tree also takes the seed: scikit-learn permutes candidate features at each
# split, so an unseeded tree can differ between runs on the same data.
model = DecisionTreeClassifier(random_state=RANDOM_STATE)
model.fit(X_train, y_train)  # learn from the training rows only
predictions = model.predict(X_test)  # predict labels for the held-out rows

# Fraction of held-out rows whose predicted label matches the true label.
score = accuracy_score(y_test, predictions)
score

['music-recommender.joblib']

In [11]:
# Persisting models: once the model is built and trained, save it to a file. The next time predictions are needed,
# simply load the model from that file and ask it to make predictions.

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import joblib

# music_data = pd.read_csv('music.csv') # read data from .csv file
# X = music_data.drop(columns= ['genre' ]) # input dataset
# y = music_data['genre']  # output dataset

# model = DecisionTreeClassifier() # creating model, instance of DecisionTreeClassifier class
# model.fit(X,y) # train the model 

# joblib.dump(model,'music-recommender.joblib') # saving and loading model
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code.
# Only load model files you created yourself or otherwise trust.
model = joblib.load('music-recommender.joblib')

# Wrap the sample in a DataFrame that carries the column names the model recorded
# at fit time. Passing a bare nested list to a model that was fitted on a
# DataFrame makes scikit-learn warn "X does not have valid feature names".
# If the loaded model recorded no names, getattr yields None and pandas falls
# back to default integer column labels -- presumably handled like an unnamed
# array by scikit-learn; TODO confirm on the installed version.
feature_names = getattr(model, 'feature_names_in_', None)
sample = pd.DataFrame([[21, 1]], columns=feature_names)
predictions = model.predict(sample)  # predicted label for the single sample row
predictions



array(['HipHop'], dtype=object)

In [12]:
# Visualizing a Decision Tree
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

music_data = pd.read_csv('music.csv')  # load the raw table from the CSV file
X = music_data.drop(columns=['genre'])  # input features: every column except 'genre'
y = music_data['genre']  # target labels

# Seeded so the exported tree is identical on every run; here the model is
# fitted on all rows because the goal is visualisation, not evaluation.
model = DecisionTreeClassifier(random_state=42)
model.fit(X, y)

# Write the fitted tree to a Graphviz .dot file.
# Feature and class labels are taken from the data and the fitted model rather
# than hard-coded, so node labels cannot drift out of sync with the columns the
# model was actually trained on.
tree.export_graphviz(
    model,
    out_file='music-recommender.dot',
    feature_names=list(X.columns),
    class_names=[str(label) for label in model.classes_],
    label='all',
    rounded=True,
    filled=True,
)
