# In [5]:
import findspark

# findspark.init() locates SPARK_HOME and adds its PySpark package to
# sys.path, so it must run *before* the first ``pyspark`` import. Importing
# pyspark first fails whenever PySpark is only reachable through SPARK_HOME.
findspark.init()

from pyspark.sql import SparkSession  # noqa: E402  (needs findspark.init() first)

# Local Spark session using all available cores; getOrCreate() returns the
# already-running session if one exists (e.g. when this cell is re-run).
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("arvoredecisao")
    .getOrCreate()
)

# Numeric / tabular helpers.
import numpy as np
import pandas as pd

# Plotting.
import matplotlib.pyplot as plt
import seaborn as sns

# Evaluation metrics computed locally on the collected predictions.
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    mean_absolute_error,
    mean_squared_error,
)

# Spark ML model, evaluator and feature transformer.
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import Normalizer

# Every figure drawn afterwards uses matplotlib's bundled "ggplot" stylesheet.
plt.style.use('ggplot')

class ArvoreDecisao:
    """Train a Spark ML decision tree and report its performance on a test set.

    Both DataFrames are expected to hold a numeric ``label`` column with the
    values 0 and 1 (the reports below look at exactly those two values) and a
    ``normFeatures`` vector column, which is used as the model input.

    Typical use::

        arvore = ArvoreDecisao(train_df, test_df)
        arvore.predicoes()        # fit, score the test set, print FP/FN/accuracy
        arvore.matriz_confusao()  # draw the confusion-matrix heatmap
        arvore.metricas()         # print classification report, RMSE and MAE
    """

    # Display names for label 0 and label 1, in that order.
    classes = ['DOWN', 'UP']
    # Label values matching ``classes`` position by position.
    _labels = [0, 1]

    def __init__(self, train_df, test_df):
        self.train = train_df
        self.test = test_df
        # Populated by predicoes() / _rotulos() and kept on the instance so the
        # other methods share them instead of relying on notebook globals.
        self.trained_model = None
        self.test_predictions = None
        self.real = None
        self.predito = None

    def predicoes(self):
        """Fit the tree on ``self.train``, score ``self.test`` and print a summary.

        Prints the number of false positives (label 0 predicted as 1), false
        negatives (label 1 predicted as 0) and the overall accuracy.

        Returns:
            The DataFrame produced by ``trained_model.transform(self.test)``;
            it is also stored in ``self.test_predictions``.
        """
        model = DecisionTreeClassifier(labelCol='label', featuresCol='normFeatures')
        self.trained_model = model.fit(self.train)
        test_predictions = self.trained_model.transform(self.test)
        self.test_predictions = test_predictions
        # Labels collected for a previous fit no longer match these predictions.
        self.real = None
        self.predito = None

        label = test_predictions['label']
        prediction = test_predictions['prediction']

        fp = test_predictions.filter((label == 0) & (prediction == 1)).count()
        print("Falsos positivos: ", fp)

        fn = test_predictions.filter((label == 1) & (prediction == 0)).count()
        print("Falsos negativos: ", fn)

        evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
        acuracia = evaluator.evaluate(test_predictions.select("prediction", "label"))
        print("Acurácia = " + str(acuracia))
        return test_predictions

    def _rotulos(self):
        """Return ``(real, predito)`` as 1-D NumPy arrays, collecting them once.

        Both columns are read from ``self.test_predictions`` in a single
        ``collect()``, so each true label stays paired with its own prediction.
        Calls ``predicoes()`` first (which prints its summary) if no model has
        been fitted yet.
        """
        if self.test_predictions is None:
            self.predicoes()
        if self.real is None or self.predito is None:
            rows = self.test_predictions.select('label', 'prediction').collect()
            self.real = np.array([row['label'] for row in rows])
            self.predito = np.array([row['prediction'] for row in rows])
        return self.real, self.predito

    def matriz_confusao(self, caminho=None):
        """Draw the test-set confusion matrix as an annotated heatmap.

        Args:
            caminho: optional file path; when given, the figure is also saved there.

        Returns:
            The matplotlib Figure containing the heatmap.
        """
        real, predito = self._rotulos()
        # Explicit labels keep the matrix 2x2 even if one class is absent.
        cm = confusion_matrix(real, predito, labels=self._labels)

        fig, ax = plt.subplots()
        sns.heatmap(
            pd.DataFrame(cm, index=self.classes, columns=self.classes),
            annot=True, cmap="RdGy", fmt='g', ax=ax,
        )
        ax.xaxis.set_label_position("top")
        ax.set_title('Matriz de confusão', y=1.1)
        ax.set_ylabel('Label real')
        ax.set_xlabel('Label predita')
        # Must come after title/labels exist, otherwise their space isn't
        # reserved and they get clipped.
        fig.tight_layout()
        if caminho is not None:
            fig.savefig(caminho)
        return fig

    def metricas(self):
        """Print the classification report, RMSE and MAE of the test predictions.

        Returns:
            dict with the printed ``'rmse'`` and ``'mae'`` values.
        """
        real, predito = self._rotulos()
        target_names = ["Classe {}".format(c) for c in self.classes]
        # ``labels`` pins the row order so target_names always line up.
        print(classification_report(real, predito, labels=self._labels,
                                    target_names=target_names))

        # sqrt(MSE) rather than mean_squared_error(..., squared=False): the
        # ``squared`` argument was deprecated in scikit-learn 1.4, removed in 1.6.
        rmse = float(np.sqrt(mean_squared_error(real, predito)))
        print("RMSE = " + str(rmse))

        mae = float(mean_absolute_error(real, predito))
        print("MAE = " + str(mae))
        return {'rmse': rmse, 'mae': mae}