In [10]:
from PivotTree import *
from RuleTree import *

In [14]:
import numpy as np
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import ColumnDataSource, HoverTool, Legend
from bokeh.layouts import column
from bokeh.models import ColumnDataSource, Label, LabelSet, Range1d

output_notebook()

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.datasets import load_breast_cancer


# Load the breast cancer dataset shipped with scikit-learn.
breast_cancer = load_breast_cancer()
n_classes = 2

# Column indices of the two features that become the plot's x and y axes.
feat_1, feat_2 = 7, 8

X = breast_cancer.data[:, [feat_1, feat_2]]
y = breast_cancer.target

# Shuffle samples and labels with one shared, seeded permutation so the row
# order (and the row indices hard-coded further down) is reproducible.
np.random.seed(20)
idx = np.arange(X.shape[0])
np.random.shuffle(idx)
X, y = X[idx], y[idx]

# Normalize the features: zero mean and unit standard deviation per column.
mean, std = X.mean(axis=0), X.std(axis=0)
X = (X - mean) / std

# Fill colours shared by the class markers and the decision-region palette.
area_colors = ['#ff9999', '#9999ff', '#9999ff']
label_cmap = ListedColormap(area_colors)

# Fit a depth-2 PivotTree with both oblique-split options switched off,
# then print its textual description.
clf = PivotTree(
    max_depth=2,
    random_state=0,
    allow_oblique_splits=False,
    force_oblique_splits=False,
)
clf.fit(X, y)
print(clf.print_tree())

# Evaluation grid: the data's bounding box padded by 0.2 on every side,
# sampled every `plot_step` units along both axes.
plot_step = 0.02
x_min, y_min = X.min(axis=0) - 0.2
x_max, y_max = X.max(axis=0) + 0.2
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, plot_step),
    np.arange(y_min, y_max, plot_step),
)

# Predict on the mesh grid, then fold the flat predictions back into the
# grid's 2-D shape.
grid_points = np.column_stack((xx.ravel(), yy.ravel()))
Z = clf.predict(grid_points).reshape(xx.shape)


# Prepare the data for Bokeh.

# Class name of every sample, looked up from its numeric target.
labels = [breast_cancer.target_names[label] for label in y]

# Dark fill colours used for the highlighted pivot markers.
darker_red = '#cc3333'
darker_blue = '#333399'

# Boolean row selectors and sample counts for the two classes.
class0_mask = y == 0
class1_mask = y == 1
n_class0 = int(class0_mask.sum())
n_class1 = int(class1_mask.sum())

# Per-class columns: coordinates, a constant fill colour and the class name
# (the `label` column feeds the hover tooltip).
triangle_data = dict(
    x=X[class0_mask, 0],
    y=X[class0_mask, 1],
    color=[area_colors[0]] * n_class0,
    label=[breast_cancer.target_names[0]] * n_class0,
)
circle_data = dict(
    x=X[class1_mask, 0],
    y=X[class1_mask, 1],
    color=[area_colors[1]] * n_class1,
    label=[breast_cancer.target_names[1]] * n_class1,
)

circle_source = ColumnDataSource(circle_data)
triangle_source = ColumnDataSource(triangle_data)


# Plot canvas: axis titles come from the dataset's feature names and both
# ranges are pinned to the padded data extent computed for the mesh grid.
p = figure(
    x_axis_label=breast_cancer.feature_names[feat_1],
    y_axis_label=breast_cancer.feature_names[feat_2],
    frame_width=800,
    frame_height=600,
    x_range=(x_min, x_max),
    y_range=(y_min, y_max),
)

# Identical styling for both axes: large bold titles, while tick labels
# (zero-size font) and major/minor tick marks are hidden.
for axis in (p.xaxis, p.yaxis):
    axis.axis_label_text_font_size = "36pt"
    axis.axis_label_text_font_style = "bold"
    axis.major_label_text_font_size = "0pt"
    axis.major_tick_line_color = None
    axis.minor_tick_line_color = None


# Class scatter layers. `scatter(marker=...)` is used instead of the
# `circle(size=...)` / `triangle(...)` convenience methods, which recent
# Bokeh releases deprecate; the rendered markers are the same.
circles = p.scatter('x', 'y', marker='circle', size=10, color='color',
                    source=circle_source, legend_label='Benign',
                    line_color='#6666ff', alpha=1.0)
triangles = p.scatter('x', 'y', marker='triangle', size=10, color='color',
                      source=triangle_source, legend_label='Malignant',
                      line_color='#ff6666', alpha=1.0)

# Row indices (into the shuffled X) of the samples highlighted as pivots,
# stored as (index, index) pairs.
pivot_points = [(379, 379), (351, 351)]
circle_pivot = pivot_points[0][0]
triangle_pivot = pivot_points[1][0]

# Oversized, black-outlined markers for the two pivots. The legend text is
# built from the plotted row index so it cannot drift from the marker it names.
p.scatter(X[circle_pivot, 0], X[circle_pivot, 1], marker='circle', size=30,
          color=darker_blue, legend_label=f'Pivot {circle_pivot}',
          line_color='black')
p.scatter(X[triangle_pivot, 0], X[triangle_pivot, 1], marker='triangle', size=30,
          color=darker_red, legend_label=f'Pivot {triangle_pivot}',
          line_color='black')


# Hover tooltip: point index, cursor coordinates and the source's `label` column.
hover = HoverTool(
    tooltips=[("Index", "$index"), ("(x,y)", "($x, $y)"), ("Label", "@label")]
)
p.add_tools(hover)

# Legend appearance, applied one property at a time.
legend_settings = {
    "location": "top_right",
    "title": '',
    "label_text_font_size": "18pt",
    "glyph_height": 20,
    "glyph_width": 20,
    "spacing": 10,                  # gap between legend entries
    "padding": 10,                  # space around the legend box
    "background_fill_alpha": 0.5,   # legend box transparency
}
for option, value in legend_settings.items():
    setattr(p.legend, option, value)

# Hide both grids.
p.xgrid.visible = False
p.ygrid.visible = False

# Replace the automatically generated legend entries with just the two
# class entries.
p.legend.items = [
    (text, [renderer])
    for text, renderer in (("Benign", circles), ("Malignant", triangles))
]


# Shade the decision regions with the classifier's grid predictions, using
# the `area_colors` palette defined earlier in this cell.
# level="image" places the raster on Bokeh's lowest render level so it sits
# beneath the scatter markers; at the default level it would be painted over
# them because it is added to the figure after they are.
p.image(image=[Z], x=x_min, y=y_min, dw=x_max - x_min, dh=y_max - y_min,
        palette=area_colors, alpha=0.4, level="image")


# Styling shared by both pivot captions: positioned in screen units, with a
# fully transparent border and background, in 28pt text.
label_style = dict(
    x_units='screen',
    y_units='screen',
    border_line_color='black',
    border_line_alpha=0.0,
    background_fill_color='white',
    background_fill_alpha=0.0,
    text_font_size='28pt',
)

citation1 = Label(x=235, y=150, text='pivot 1', **label_style)
citation0 = Label(x=600, y=240, text='pivot 0', **label_style)

# Attach the captions in the same order as before, then render the figure.
for annotation in (citation1, citation0):
    p.add_layout(annotation)

show(p)


[(True, ['node_id: 0  pivot: 351'], [1.0], 3.6068702936172485, False, 0), (True, ['node_id: 1  pivot: 379'], [1.0], 0.3927917182445526, False, 1), (False, 1, 18, 0.03163444639718805, 3, 2), (False, 0, 194, 0.3409490333919156, 4, 2), (False, 1, 357, 0.6274165202108963, 2, 1)]
|-+ if node_id: 0  pivot: 351 <= 3.61:
  |-+ if node_id: 1  pivot: 379 <= 0.39:
    |--> label: 1 (18, 0.03)
    |--> label: 0 (194, 0.34)
  |--> label: 1 (357, 0.63)

