In [1]:
import numpy as np
import pandas as pd

from lets_plot import *

import sys; sys.path.insert(0, "..")
from residual2 import *

In [2]:
# Configure lets-plot to render figures as HTML in the notebook output (must run before any plot is displayed).
LetsPlot.setup_html()

In [3]:
_KIND_DEF = "scatter"
_MARGINAL_SIZE_DEF = .1

def _get_geom(kind):
    if kind in ['point', 'scatter', 'reg']:
        return 'point'
    if kind in ['contour', 'density', 'kde']:
        return 'contour'
    if kind in ['tile', 'hist']:
        return 'tile'
    raise Exception("Unknown joint plot kind '{0}'".format(kind))

def _get_marginal(kind, color_by=None):
    if color_by is not None or kind in ['contour', 'density', 'kde']:
        return 'dens'
    if kind in ['point', 'scatter', 'reg', 'tile', 'hist']:
        return 'hist'
    raise Exception("Unknown joint plot kind '{0}'".format(kind))

def joint_plot(data, x, y, *,
               kind=None,
               color_by=None,
               marginal_size=None,
               color=None, size=None, alpha=None):
    """Draw ``y`` against ``x`` with marginal distribution panels.

    Built on ``residual_plot`` with ``method='none'`` and ``hline=False``, so
    the joint panel shows the raw data rather than residuals.

    Parameters
    ----------
    data :
        Source data; the cells below pass a pandas DataFrame.
    x, y : str
        Column names for the horizontal and vertical axes.
    kind : str, optional
        'scatter' (default), 'point', 'reg', 'contour', 'density', 'kde',
        'tile' or 'hist'. 'reg' additionally overlays ``geom_smooth()``.
        Falsy values fall back to ``_KIND_DEF``.
    color_by : str, optional
        Grouping column forwarded to ``residual_plot``. It also switches the
        marginals to densities.
    marginal_size : float, optional
        Marginal panel size. Falsy values (None, 0) fall back to
        ``_MARGINAL_SIZE_DEF``.
    color, size, alpha : optional
        Forwarded unchanged to ``residual_plot``.

    Returns
    -------
    The figure returned by ``residual_plot``, with an optional smoother and
    ``theme_classic()`` added.

    Raises
    ------
    Exception
        From the ``_get_geom`` / ``_get_marginal`` helpers when ``kind`` is
        not recognised.
    """
    my_kind = kind or _KIND_DEF
    # Resolve the geom first so an unknown kind is reported before any plotting work.
    geom = _get_geom(my_kind)
    marginal_kind = _get_marginal(my_kind, color_by)
    # "<type>:tr:<size>" is the residual_plot marginal spec; 'tr' presumably
    # places marginals on the top and right sides. Falsy sizes use the default.
    marginal_spec = "{0}:tr:{1}".format(marginal_kind, marginal_size or _MARGINAL_SIZE_DEF)
    p = residual_plot(data, x=x, y=y,
                      method='none',
                      geom=geom,
                      color=color, size=size, alpha=alpha,
                      color_by=color_by,
                      hline=False,
                      marginal=marginal_spec)
    # Check the *resolved* kind so the regression overlay also applies when
    # 'reg' comes from the default rather than the explicit argument.
    if my_kind == 'reg':
        p += geom_smooth()
    p += theme_classic()
    return p

In [4]:
# Palmer penguins dataset from the seaborn-data repository; some rows contain
# missing measurements (see row 3 in the output below).
PENGUINS_URL = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv"
df = pd.read_csv(PENGUINS_URL)
print(df.shape)
df.head()

(344, 7)


Unnamed: 0,species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex
0,Adelie,Torgersen,39.1,18.7,181.0,3750.0,MALE
1,Adelie,Torgersen,39.5,17.4,186.0,3800.0,FEMALE
2,Adelie,Torgersen,40.3,18.0,195.0,3250.0,FEMALE
3,Adelie,Torgersen,,,,,
4,Adelie,Torgersen,36.7,19.3,193.0,3450.0,FEMALE


In [5]:
# Default kind ("scatter"): point layer in the joint panel, histogram marginals.
joint_plot(df, "bill_length_mm", "bill_depth_mm")

In [6]:
# Grouping by species: color_by switches the marginals from histograms to densities.
joint_plot(df, "bill_length_mm", "bill_depth_mm", color_by="species")

In [7]:
# kind="kde": contour layer in the joint panel, density marginals, grouped by species.
joint_plot(df, "bill_length_mm", "bill_depth_mm", color_by="species", kind="kde")

In [8]:
# kind="reg": points plus a geom_smooth() overlay, histogram marginals.
joint_plot(df, "bill_length_mm", "bill_depth_mm", kind="reg")

In [9]:
# kind="hist": tile layer in the joint panel, histogram marginals.
joint_plot(df, "bill_length_mm", "bill_depth_mm", kind="hist")

In [10]:
# Larger marginal panels: marginal_size overrides the 0.1 default.
joint_plot(df, "bill_length_mm", "bill_depth_mm", marginal_size=.25)