# Classes for data analytics
Provides an abstract class for training models and validating results.

In [None]:
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error, confusion_matrix
from sklearn.model_selection import cross_val_score
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import itertools


class Predictor(object):
    """Fit a predictive model on construction and evaluate it.

    The model is fitted on the supplied training set inside ``__init__``.
    ``mse`` and ``confusion_matrix`` score the fitted model against that
    same training data; ``cross_validate`` scores it with k-fold
    cross-validation over the same data.
    """

    def __init__(self, model, training_set, labels):
        """
        :param object model: a sklearn predictive model (``fit`` and
            ``predict`` are called on it)
        :param nparray training_set: data for training
        :param array-like labels: expected labels (targets) for the data
        """
        self.model = model
        self.T = training_set
        self.L = labels
        self.model.fit(self.T, self.L)

    def mse(self):
        """
        Score the fitted model on its own training data.

        Despite the method name, the value returned is the square root of
        the mean squared error (i.e. the RMSE), not the raw MSE.

        :return: root mean squared error between labels and predictions
        """
        p = self.model.predict(self.T)
        return np.sqrt(mean_squared_error(self.L, p))

    def cross_validate(self, cv=10):
        """
        Cross-validate the model on the training data.

        :param cv: cross-validation setting forwarded to
            ``cross_val_score`` (defaults to 10, the previously
            hard-coded value)
        :return: array with one RMSE value per fold
        """
        scores = cross_val_score(self.model, self.T, self.L,
                                 scoring='neg_mean_squared_error', cv=cv)
        # The scorer reports negated MSE, so flip the sign before the root.
        return np.sqrt(-scores)

    def class_labels(self):
        """
        Return the distinct training labels in sorted order.

        This is the row/column order used by ``confusion_matrix`` and is
        therefore the sequence to pass as ``classes`` to ``cm_plot``.

        :return: sorted array of unique labels
        """
        return np.unique(self.L)

    def confusion_matrix(self):
        """
        Build the confusion matrix of the model on its training data.

        Rows and columns follow ``class_labels()`` (sorted unique labels),
        so the layout is reproducible from run to run. Iterating a ``set``
        of labels, as was done before, yields an arbitrary order.

        :return: confusion matrix with true labels on the rows and
            predicted labels on the columns
        """
        p = self.model.predict(self.T)
        # Inside the method body the bare name resolves to the module-level
        # sklearn function, not to this method.
        cms = confusion_matrix(self.L, p, labels=self.class_labels())
        return cms

    @staticmethod
    def cm_plot(ax, classes, CM, title, figure):
        """
        Draw a confusion matrix as an annotated heat map.

        :param ax: matplotlib axes to draw on
        :param classes: tick labels, in the same order as the rows and
            columns of ``CM``
        :param CM: two-dimensional confusion matrix (any array-like)
        :param str title: title shown above the plot
        :param figure: matplotlib figure owning ``ax``; used to attach
            the colour bar
        """
        # Accept plain nested lists as well as arrays.
        CM = np.asarray(CM)
        im = ax.imshow(CM, interpolation='nearest', cmap=plt.cm.Blues)
        # Reserve a narrow strip to the right of the axes for the colour bar.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        figure.colorbar(im, cax=cax, orientation='vertical')
        tick_marks = np.arange(len(classes))
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(classes, rotation=90, fontsize=12)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(classes, rotation=0, fontsize=12)
        ax.set_title(title, fontsize=16)
        # Cells above half the maximum are dark, so their text is white.
        thresh = CM.max() / 2.
        for i, j in itertools.product(range(CM.shape[0]), range(CM.shape[1])):
            ax.text(j, i, CM[i, j], horizontalalignment="center",
                    color="white" if CM[i, j] > thresh else "black", fontsize=12)
        ax.set_ylabel('True label', fontsize=16)
        ax.set_xlabel('Predicted label', fontsize=16)