In [1]:
%matplotlib notebook

In [2]:
# Display the matplotlib backend currently in use; this is the cell's output
# value (shown as 'nbAgg' in the recorded run).
import matplotlib as mpl
mpl.get_backend()

'nbAgg'

In [9]:

# yellowbrick.features.scatterplotmatrix
# Implementation of a scatter plot matrix for pairwise feature analysis.
#
# Author:   Prema Damodaran Roman 
# Created:  
#
# Copyright (C) 2017 District Data Labs
# For license information, see LICENSE.txt
#
# ID: scatterplotmatrix.py

##########################################################################
## Imports
##########################################################################
#from matplotlib.backends.backend_agg import FigureCanvasAgg

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from yellowbrick.features.base import DataVisualizer
from yellowbrick.exceptions import YellowbrickValueError
from yellowbrick.bestfit import draw_best_fit
from yellowbrick.utils import is_dataframe

class ScatterPlotMatrixVisualizer(DataVisualizer):
    """
    ScatterPlotMatrixVisualizer draws a square grid of subplots with one row
    and one column per feature.  The subplot in row i, column j shows feature
    j on the x axis against feature i on the y axis, and each diagonal subplot
    shows the univariate distribution of its feature as a histogram stacked
    by class.
    """

    def __init__(self, ax=None, features=None, classes=None,
                 nondiag_plot='scatter', nondiag_args=None,
                 diag_plot='hist', diag_args=None,
                 size=10, ratio=5, space=.2, **kwargs):
        """
        Store the options that control how the matrix is drawn.

        Parameters
        ----------

        ax: Passed through to the super class constructor.  The matrix itself
            is drawn on a new figure created by ``draw``.

        features: Names of the features to plot, one row/column per name.
            When X is a DataFrame the names select its columns; otherwise the
            columns of X are taken by position and the names are only used
            as axis labels.

        classes: The labels found in y.  Diagonal histograms contain one
            stacked series per label, in iteration order.

        nondiag_plot: The type of plot to render off the diagonal.  The
            choices are 'scatter' and 'hex'.

        nondiag_args: Keyword arguments used for customizing the off-diagonal
            plots.  The dictionary is never modified.
                Property        Description
                aspect          aspect ratio of the axes (default 'auto')
                fit             scatter only: draw a best fit line using
                                yellowbrick.bestfit (default True)
                estimator       scatter only: estimator name handed to
                                draw_best_fit (default 'linear')
                x_bins, y_bins  hex only: averaged to obtain the hexbin
                                gridsize (default 50 each)
            Every other key is forwarded to ``Axes.scatter`` or
            ``Axes.hexbin``.

        diag_plot: The type of plot to render on the diagonal.  Currently,
            the only choice is 'hist'.

        diag_args: Keyword arguments used for customizing the diagonal plots.
            The dictionary is never modified.
                Property        Description
                bins            number of histogram bins (default 50)
                alpha           transparency (default 0.4)
                histcolor       bar color, used only when the histogram is
                                not split by class (default '#6897bb')
            Every other key is forwarded to ``Axes.hist``.

        size: Size of each side of the figure in inches.

        ratio: Stored on the instance but not used by the current
            implementation; retained so existing calls keep working.

        space: Spacing between subplots, used for both hspace and wspace.

        kwargs: Keyword arguments passed to the super class.

        """
        # Local import so this block does not depend on the import cell.
        import warnings

        # Warn only when Matplotlib is older than 2.0; the previous equality
        # test against '2.0.0' also fired for every newer release.
        try:
            major_version = int(mpl.__version__.split('.')[0])
        except ValueError:
            major_version = None  # unparseable version string: skip the check
        if major_version is not None and major_version < 2:
            warnings.warn(
                'This Visualizer requires Matplotlib version 2.0.0 or later. '
                'Please upgrade to continue.'
            )

        super(ScatterPlotMatrixVisualizer, self).__init__(ax, **kwargs)

        self.features_ = features
        self.classes_ = classes
        self.nondiag_plot = nondiag_plot
        self.nondiag_args = nondiag_args
        self.diag_plot = diag_plot
        self.diag_args = diag_args
        self.size = size
        self.ratio = ratio
        self.space = space

    @staticmethod
    def _padded_limits(values, margin=0.05):
        """
        Return (low, high) axis limits that enclose ``values`` with a small
        margin on each side so that extreme points are not drawn on the frame.
        """
        values = np.asarray(values)
        low = float(np.nanmin(values))
        high = float(np.nanmax(values))
        pad = (high - low) * margin
        if pad == 0:
            # Constant feature: avoid a zero-width axis.
            pad = 0.5
        return low - pad, high + pad

    def draw(self, X, y, **kwargs):
        """
        Creates the grid of subplots and renders every cell of the matrix.
        draw calls draw_diag for the diagonal cells and draw_nondiag for all
        other cells, then stores the figure on ``self.fig``.

        Parameters
        ----------
        X: DataFrame or 2D array with one column per feature.

        y: Class label for each row of X, or None.

        kwargs: Accepted for interface compatibility; not used.

        """
        features = self.features_
        if features is None:
            # No names were supplied: fall back to the columns of X.
            if is_dataframe(X):
                features = X.columns
            else:
                features = range(np.asarray(X).shape[1])
        features = list(features)
        num_features = len(features)

        if num_features == 0:
            raise YellowbrickValueError(
                "at least one feature is required to draw a scatter plot matrix"
            )

        # Extract every column once, as plain arrays.
        if is_dataframe(X):
            columns = [np.asarray(X[name]) for name in features]
        else:
            data = np.asarray(X)
            if data.ndim != 2 or data.shape[1] != num_features:
                raise YellowbrickValueError(
                    "X must be two dimensional with one column per feature"
                )
            columns = [data[:, idx] for idx in range(num_features)]

        fig, axes = plt.subplots(num_features, num_features,
                                 figsize=(self.size, self.size),
                                 sharex="col",
                                 sharey="row",
                                 squeeze=False)
        fig.subplots_adjust(hspace=self.space, wspace=self.space)

        for i, row_feature in enumerate(features):
            for j, col_feature in enumerate(features):
                ax = axes[i, j]

                if i == j:
                    limits = self._padded_limits(columns[i])
                    ax.set(xlim=limits, ylim=limits)
                    # Histogram counts live on a twin axis so they do not
                    # take part in the row's shared (feature valued) y scale.
                    hist_ax = ax.twinx()
                    hist_ax.yaxis.set_visible(False)
                    hist_ax.grid(False)
                    self.draw_diag(self, hist_ax, columns[i], y)
                else:
                    # Column feature on x and row feature on y, matching the
                    # shared axes and the labels set below.
                    self.draw_nondiag(self, ax, col_feature, row_feature,
                                      columns[j], columns[i], y)

                # Only the left column shows a y axis ...
                if j != 0:
                    ax.yaxis.set_visible(False)
                else:
                    ax.set_ylabel(row_feature)

                # ... and only the bottom row shows an x axis.
                if i != num_features - 1:
                    ax.xaxis.set_visible(False)
                else:
                    ax.set_xlabel(col_feature)

        self.fig = fig

    @staticmethod
    def draw_nondiag(self, ax, xlab, ylab, xvals, yvals, y):
        """
        Draws one off-diagonal cell: a scatter plot (optionally with a best
        fit line) or a hexbin plot of yvals against xvals.

        This is a static method that receives the visualizer explicitly as
        its first argument; that calling convention is kept so existing
        calls continue to work.

        Parameters
        ----------
        self: The visualizer whose nondiag_plot / nondiag_args are used.

        ax: The matplotlib axes to draw on.

        xlab, ylab, y: Accepted for interface compatibility; not used.

        xvals, yvals: Values for the x and y axes.

        Raises
        ------
        YellowbrickValueError: if nondiag_plot is not 'scatter' or 'hex'.

        """
        # Work on a copy so that every cell sees the same options and the
        # caller's dictionary is left untouched.
        args = dict(self.nondiag_args or {})

        # Remove the options that are not matplotlib keywords before anything
        # is forwarded to scatter / hexbin.
        aspect = args.pop("aspect", "auto")
        fit = args.pop("fit", True)
        estimator = args.pop("estimator", "linear")
        x_bins = args.pop("x_bins", 50)
        y_bins = args.pop("y_bins", 50)

        xvals = np.asarray(xvals)
        yvals = np.asarray(yvals)

        if self.nondiag_plot == "scatter":
            ax.set_aspect(aspect)
            ax.set(xlim=self._padded_limits(xvals),
                   ylim=self._padded_limits(yvals))
            ax.scatter(xvals, yvals, **args)

            if fit:
                draw_best_fit(xvals, yvals, ax, estimator)

        elif self.nondiag_plot == "hex":
            ax.set_aspect(aspect)
            ax.set(xlim=self._padded_limits(xvals),
                   ylim=self._padded_limits(yvals))
            # Explicit hexbin keywords supplied by the user take precedence.
            args.setdefault("gridsize", int(np.mean([x_bins, y_bins])))
            args.setdefault("cmap", 'Blues')
            args.setdefault("mincnt", 1)
            ax.hexbin(xvals, yvals, **args)

        else:
            raise YellowbrickValueError(
                "nondiag_plot must be 'scatter' or 'hex', got {!r}".format(
                    self.nondiag_plot
                )
            )

    @staticmethod
    def draw_diag(self, ax, vals, y):
        """
        Draws one diagonal cell: a histogram of vals, stacked by class when
        both y and the visualizer's classes are available.

        This is a static method that receives the visualizer explicitly as
        its first argument; that calling convention is kept so existing
        calls continue to work.

        Parameters
        ----------
        self: The visualizer whose diag_plot / diag_args / classes_ are used.

        ax: The matplotlib axes to draw on.

        vals: Values of the feature shown in this cell.

        y: Class label for each entry of vals, or None.

        Raises
        ------
        YellowbrickValueError: if diag_plot is not 'hist'.

        """
        if self.diag_plot != "hist":
            raise YellowbrickValueError(
                "diag_plot must be 'hist', got {!r}".format(self.diag_plot)
            )

        # Copy so repeated calls see the same options.
        args = dict(self.diag_args or {})
        bins = args.pop("bins", 50)
        histcolor = args.pop("histcolor", "#6897bb")
        args.setdefault("alpha", 0.4)

        vals = np.asarray(vals)
        value_range = (float(np.nanmin(vals)), float(np.nanmax(vals)))

        if y is None or self.classes_ is None:
            groups = vals
            # A single color is only valid for a single series.
            args.setdefault("color", histcolor)
        else:
            # Object dtype forces an element-wise comparison even when the
            # labels and the entries of y have different types.
            labels = np.asarray(y, dtype=object)
            groups = [
                vals[np.asarray(labels == label, dtype=bool)]
                for label in self.classes_
            ]

        ax.hist(groups, bins=bins, range=value_range, stacked=True, **args)
        ax.set_xlim(self._padded_limits(vals))

    def poof(self, **kwargs):
        """
        Finalizes the figure by delegating to finalize.  Returns None.

        Parameters
        ----------
        kwargs: generic keyword arguments passed to finalize.

        """
        self.finalize(**kwargs)

    def finalize(self, **kwargs):
        """
        Finalize executes any subclass-specific axes finalization steps.
        The user calls poof and poof calls finalize.  Here it sets the
        figure title; draw must have run first so that the figure exists.

        Parameters
        ----------
        kwargs: generic keyword arguments.

        """
        self.fig.suptitle("Scatter Plot Matrix")



In [4]:
import pandas as pd
from sklearn import datasets

iris = datasets.load_iris()

# Translate the integer target codes into species names using the lookup
# table that ships with the dataset (index 0, 1, 2 -> setosa, versicolor,
# virginica).
species_names = [str(name) for name in iris.target_names[iris.target]]

feature_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
col_names = feature_names + ['species']

# Build the frame column-wise so the measurements stay numeric; stacking them
# in one array together with the species strings would coerce every value to
# a string.
iris_df = pd.DataFrame(iris.data, columns=feature_names)
iris_df['species'] = species_names


In [5]:
# Every column except the last is a feature; the last column holds the label.
cols = iris_df.columns.values
features = cols[:-1]
target = cols[-1]

# Cast so the feature matrix is numeric regardless of how iris_df stores it.
X = iris_df[features].astype(float)
y = iris_df[target]

# Sorted list instead of a bare set: set iteration order for strings can
# change between kernel restarts, which would reorder (and recolor) the
# per-class series in the plots.
classes = sorted(set(y))

In [10]:
# Build the visualizer with the feature names and class labels prepared above.
visualizer = ScatterPlotMatrixVisualizer(features=features, classes=classes)
# fit is not defined on ScatterPlotMatrixVisualizer itself, so it comes from
# a parent class; presumably it calls draw(X, y) - confirm against DataVisualizer.
visualizer.fit(X, y)   # Fit the data to the visualizer
# poof() as defined in this notebook has no return statement, so g is None.
g = visualizer.poof() 

<IPython.core.display.Javascript object>

