<a href="https://colab.research.google.com/github/ahmed-boutar/interpreting-rule-based-models/blob/main/interpreting_rule_based_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [117]:
import os

import numpy as np
import pandas as pd
from graphviz import Digraph
from imodels import OneRClassifier, OptimalRuleListClassifier, SlipperClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split

## Dataset Description

I train the models on the Breast Cancer Wisconsin (Diagnostic) dataset, which is provided in the scikit-learn library. It is a popular dataset in machine learning, particularly for binary classification tasks. Here, the target is whether a tumor is benign or malignant. (Source: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_breast_cancer.html)
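Because the later precision and recall numbers depend on which label counts as "positive", it helps to confirm how the target is encoded. A minimal check using only scikit-learn's loader (`label_names` is just an illustrative variable):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()

# Map each integer label to its class name
label_names = {i: str(name) for i, name in enumerate(data.target_names)}
print(label_names)  # {0: 'malignant', 1: 'benign'}

# Number of samples per label (index 0 = malignant, index 1 = benign)
print(np.bincount(data.target))  # [212 357]
```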

#### Provenance 

This dataset contains features computed from digitized images of fine-needle aspirates (FNA) of breast masses. It is used to predict whether a breast tumor is malignant or benign from characteristics of the cell nuclei visible in the images.

#### Authors & License 

The Breast Cancer Wisconsin (Diagnostic) dataset is part of scikit-learn, which is distributed under the *BSD 3-Clause license*, allowing it to be freely used for academic, commercial, and personal projects (provided the original copyright notice and the BSD 3-Clause license text are included).

The dataset was created by Dr. William H. Wolberg, W. Nick Street, and Olvi L. Mangasarian at the University of Wisconsin.

#### Overview 
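The dataset contains 569 samples and 30 numerical features. Ten characteristics are measured for each cell nucleus in an image: radius, texture, perimeter, area, smoothness, compactness, concavity, concave points, symmetry, and fractal dimension. For each image, the mean, standard error, and "worst" (the mean of the three largest values) of each characteristic are computed, giving 10 × 3 = 30 features.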
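The column names follow this pattern (`mean <x>`, `<x> error` for the standard error, `worst <x>`). A minimal sketch that regroups scikit-learn's `feature_names` by characteristic; the dictionary `stats_by_characteristic` is only an illustration:

```python
from collections import defaultdict

from sklearn.datasets import load_breast_cancer

feature_names = load_breast_cancer().feature_names

# Group the 30 column names by the underlying nucleus characteristic
stats_by_characteristic = defaultdict(list)
for name in feature_names:
    if name.startswith('mean '):
        stats_by_characteristic[name[len('mean '):]].append('mean')
    elif name.startswith('worst '):
        stats_by_characteristic[name[len('worst '):]].append('worst')
    elif name.endswith(' error'):
        stats_by_characteristic[name[:-len(' error')]].append('standard error')

print(len(feature_names), len(stats_by_characteristic))  # 30 10
print(stats_by_characteristic['radius'])  # ['mean', 'standard error', 'worst']
```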


In [118]:
breast_cancer = load_breast_cancer()
df = pd.DataFrame(breast_cancer.data, columns=breast_cancer.feature_names)
# Add the target variable, where 1 is benign and 0 is malignant
df['diagnosis'] = breast_cancer.target
df.head()

| | mean radius | mean texture | mean perimeter | mean area | mean smoothness | mean compactness | mean concavity | mean concave points | mean symmetry | mean fractal dimension | ... | worst texture | worst perimeter | worst area | worst smoothness | worst compactness | worst concavity | worst concave points | worst symmetry | worst fractal dimension | diagnosis |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 17.99 | 10.38 | 122.8 | 1001.0 | 0.1184 | 0.2776 | 0.3001 | 0.1471 | 0.2419 | 0.07871 | ... | 17.33 | 184.6 | 2019.0 | 0.1622 | 0.6656 | 0.7119 | 0.2654 | 0.4601 | 0.1189 | 0 |
| 1 | 20.57 | 17.77 | 132.9 | 1326.0 | 0.08474 | 0.07864 | 0.0869 | 0.07017 | 0.1812 | 0.05667 | ... | 23.41 | 158.8 | 1956.0 | 0.1238 | 0.1866 | 0.2416 | 0.186 | 0.275 | 0.08902 | 0 |
| 2 | 19.69 | 21.25 | 130.0 | 1203.0 | 0.1096 | 0.1599 | 0.1974 | 0.1279 | 0.2069 | 0.05999 | ... | 25.53 | 152.5 | 1709.0 | 0.1444 | 0.4245 | 0.4504 | 0.243 | 0.3613 | 0.08758 | 0 |
| 3 | 11.42 | 20.38 | 77.58 | 386.1 | 0.1425 | 0.2839 | 0.2414 | 0.1052 | 0.2597 | 0.09744 | ... | 26.5 | 98.87 | 567.7 | 0.2098 | 0.8663 | 0.6869 | 0.2575 | 0.6638 | 0.173 | 0 |
| 4 | 20.29 | 14.34 | 135.1 | 1297.0 | 0.1003 | 0.1328 | 0.198 | 0.1043 | 0.1809 | 0.05883 | ... | 16.67 | 152.2 | 1575.0 | 0.1374 | 0.205 | 0.4 | 0.1625 | 0.2364 | 0.07678 | 0 |


In [119]:
df.describe()

| | mean radius | mean texture | mean perimeter | mean area | mean smoothness | mean compactness | mean concavity | mean concave points | mean symmetry | mean fractal dimension | ... | worst texture | worst perimeter | worst area | worst smoothness | worst compactness | worst concavity | worst concave points | worst symmetry | worst fractal dimension | diagnosis |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | ... | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 | 569.0 |
| mean | 14.127292 | 19.289649 | 91.969033 | 654.889104 | 0.09636 | 0.104341 | 0.088799 | 0.048919 | 0.181162 | 0.062798 | ... | 25.677223 | 107.261213 | 880.583128 | 0.132369 | 0.254265 | 0.272188 | 0.114606 | 0.290076 | 0.083946 | 0.627417 |
| std | 3.524049 | 4.301036 | 24.298981 | 351.914129 | 0.014064 | 0.052813 | 0.07972 | 0.038803 | 0.027414 | 0.00706 | ... | 6.146258 | 33.602542 | 569.356993 | 0.022832 | 0.157336 | 0.208624 | 0.065732 | 0.061867 | 0.018061 | 0.483918 |
| min | 6.981 | 9.71 | 43.79 | 143.5 | 0.05263 | 0.01938 | 0.0 | 0.0 | 0.106 | 0.04996 | ... | 12.02 | 50.41 | 185.2 | 0.07117 | 0.02729 | 0.0 | 0.0 | 0.1565 | 0.05504 | 0.0 |
| 25% | 11.7 | 16.17 | 75.17 | 420.3 | 0.08637 | 0.06492 | 0.02956 | 0.02031 | 0.1619 | 0.0577 | ... | 21.08 | 84.11 | 515.3 | 0.1166 | 0.1472 | 0.1145 | 0.06493 | 0.2504 | 0.07146 | 0.0 |
| 50% | 13.37 | 18.84 | 86.24 | 551.1 | 0.09587 | 0.09263 | 0.06154 | 0.0335 | 0.1792 | 0.06154 | ... | 25.41 | 97.66 | 686.5 | 0.1313 | 0.2119 | 0.2267 | 0.09993 | 0.2822 | 0.08004 | 1.0 |
| 75% | 15.78 | 21.8 | 104.1 | 782.7 | 0.1053 | 0.1304 | 0.1307 | 0.074 | 0.1957 | 0.06612 | ... | 29.72 | 125.4 | 1084.0 | 0.146 | 0.3391 | 0.3829 | 0.1614 | 0.3179 | 0.09208 | 1.0 |
| max | 28.11 | 39.28 | 188.5 | 2501.0 | 0.1634 | 0.3454 | 0.4268 | 0.2012 | 0.304 | 0.09744 | ... | 49.54 | 251.2 | 4254.0 | 0.2226 | 1.058 | 1.252 | 0.291 | 0.6638 | 0.2075 | 1.0 |


In [120]:
print(df['diagnosis'].value_counts())

diagnosis
1    357
0    212
Name: count, dtype: int64


#### Modeling

In [121]:
X = pd.DataFrame(breast_cancer.data, columns=breast_cancer.feature_names)
y = breast_cancer.target

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
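Passing `stratify=y` keeps the benign/malignant ratio the same in both splits. A quick sketch to check this, repeating the same split on the raw NumPy arrays (the row assignment is the same as with the DataFrame because the split depends only on `y`, `test_size`, and `random_state`):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42, stratify=data.target)

# Samples per class (index 0 = malignant, 1 = benign) in each split
print(len(y_train), np.bincount(y_train))  # 455 [170 285]
print(len(y_test), np.bincount(y_test))    # 114 [42 72]

# The benign fraction stays close to the full dataset's 357 / 569 in both splits
print(round(y_train.mean(), 3), round(y_test.mean(), 3))
```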

All of the models used here to classify the breast cancer diagnosis come from the imodels library. (Source: https://github.com/csinva/imodels?tab=readme-ov-file)

## One R 

One R is considered one of the simplest rule-based classification algorithms and is often used as a baseline for other methods.

Out of all the features, One R selects the one that carries the most information about the outcome of interest (in the classic formulation, the feature whose rules make the fewest errors on the training data) and creates decision rules from that **single feature**.
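To make the idea concrete, here is a minimal sketch of the classic One R procedure in NumPy, assuming continuous features are discretized into quantile bins. The bin count `n_bins` and the helpers `one_r_fit`/`one_r_predict` are hypothetical choices for illustration; this is not the imodels `OneRClassifier` implementation, whose discretization and selection details may differ.

```python
import numpy as np


def one_r_fit(X, y, n_bins=4):
    """Minimal One R sketch: bin each feature by quantiles, predict the majority
    class in each bin, and keep the feature whose bin rules make the fewest
    training errors."""
    default = int(np.bincount(y).argmax())  # fallback for bins unseen in training
    best = None
    for j in range(X.shape[1]):
        # Interior quantile edges, e.g. the 25th/50th/75th percentiles for n_bins=4
        edges = np.quantile(X[:, j], np.linspace(0, 1, n_bins + 1)[1:-1])
        bins = np.digitize(X[:, j], edges)
        # One rule per bin: predict the most frequent class among its points
        rules = {int(b): int(np.bincount(y[bins == b]).argmax()) for b in np.unique(bins)}
        errors = int(np.sum(np.array([rules[int(b)] for b in bins]) != y))
        if best is None or errors < best['errors']:
            best = {'feature': j, 'edges': edges, 'rules': rules,
                    'errors': errors, 'default': default}
    return best


def one_r_predict(model, X):
    """Apply the chosen feature's bin rules to new samples."""
    bins = np.digitize(X[:, model['feature']], model['edges'])
    return np.array([model['rules'].get(int(b), model['default']) for b in bins])


# Synthetic example: only feature 1 carries signal (label is 1 above its median)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X[:, 1] > np.median(X[:, 1])).astype(int)

model = one_r_fit(X, y)
print(model['feature'], model['errors'])  # 1 0
```

Because only feature 1 determines the label in this synthetic data, the sketch selects it with zero training errors; on real data such as the breast cancer features, the chosen feature would still make some errors.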

In [122]:
oneR_model = OneRClassifier()
oneR_model.fit(X_train, y_train, feature_names=breast_cancer.feature_names)

In [123]:
# Make predictions on the test set
y_pred = oneR_model.predict(X_test)

# Evaluate the model (precision and recall treat benign = 1 as the positive class by default)
acc = accuracy_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
print(f'Accuracy: {acc:.2f}')
print(f'Precision: {precision:.2f}')
print(f'Recall: {recall:.2f}')

Accuracy: 0.67
Precision: 0.68
Recall: 0.89
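Since `precision_score` and `recall_score` default to `pos_label=1`, the precision and recall above describe the **benign** class. To see how many malignant tumors a model catches, pass `pos_label=0`. A small sketch on toy labels (made up for illustration, using the notebook's encoding):

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Toy labels using the notebook's encoding: 0 = malignant, 1 = benign
y_true = np.array([0, 0, 0, 1, 1, 1])
y_pred = np.array([0, 1, 1, 1, 1, 1])

# Default (pos_label=1): scores describe the benign class
print(precision_score(y_true, y_pred), recall_score(y_true, y_pred))  # 0.6 1.0

# pos_label=0: scores describe the malignant class (recall is only 1/3 here)
print(precision_score(y_true, y_pred, pos_label=0),
      recall_score(y_true, y_pred, pos_label=0))
```

The same predictions look strong for benign tumors but miss two of the three malignant ones, which is the error that usually matters most in this setting.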


## Optimal Rule List (CORELS)

In [124]:
optimal_rule_list_model = OptimalRuleListClassifier()
optimal_rule_list_model.fit(X_train, y_train, feature_names=breast_cancer.feature_names)



[{'col': 'worst radius',
  'index_col': 20,
  'cutoff': 16.82,
  'val': 0.6263736263736264,
  'flip': True,
  'val_right': 0.9111842105263158,
  'num_pts': 455,
  'num_pts_right': 304},
 {'col': 'texture error',
  'index_col': 11,
  'cutoff': 0.4757,
  'val': 0.052980132450331126,
  'flip': True,
  'val_right': 1.0,
  'num_pts': 151,
  'num_pts_right': 5},
 {'col': 'worst concavity',
  'index_col': 26,
  'cutoff': 0.1932,
  'val': 0.02054794520547945,
  'flip': True,
  'val_right': 0.6,
  'num_pts': 146,
  'num_pts_right': 5},
 {'val': 0, 'num_pts': 141}]
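The printed list can be read top to bottom: each rule either captures a point or passes it on to the next rule, and the final entry is the default. The counts are consistent with this reading: 455 - 304 = 151, 151 - 5 = 146, and 146 - 5 = 141 points reach the following rules, and `val` = 0.626 matches the 285 benign cases among the 455 training points. The sketch below walks the list for a single sample. It **assumes** that `flip: True` means a rule fires when the feature is *below* the cutoff (consistent with the 304 captured points being 91% benign, since benign tumors tend to have a smaller worst radius), which is an interpretation of the printed format rather than documented imodels behavior. The helper `first_matching_rule` and the sample values are hypothetical.

```python
# The rule list printed above, copied by hand ('index_col' omitted)
rules = [
    {'col': 'worst radius', 'cutoff': 16.82, 'val': 0.6263736263736264, 'flip': True,
     'val_right': 0.9111842105263158, 'num_pts': 455, 'num_pts_right': 304},
    {'col': 'texture error', 'cutoff': 0.4757, 'val': 0.052980132450331126, 'flip': True,
     'val_right': 1.0, 'num_pts': 151, 'num_pts_right': 5},
    {'col': 'worst concavity', 'cutoff': 0.1932, 'val': 0.02054794520547945, 'flip': True,
     'val_right': 0.6, 'num_pts': 146, 'num_pts_right': 5},
    {'val': 0, 'num_pts': 141},
]


def first_matching_rule(rules, sample):
    """Return (rule index, fraction benign) for the first rule that fires.

    ASSUMPTION: flip=True means the rule fires when the feature is below
    the cutoff; flip=False means at or above it.
    """
    for i, rule in enumerate(rules):
        if 'col' not in rule:  # default rule at the end of the list
            return i, rule['val']
        value = sample[rule['col']]
        fires = value < rule['cutoff'] if rule['flip'] else value >= rule['cutoff']
        if fires:
            return i, rule['val_right']


# Each rule passes its non-captured points on to the next rule
for prev, nxt in zip(rules, rules[1:]):
    assert prev['num_pts'] - prev['num_pts_right'] == nxt['num_pts']

print(first_matching_rule(rules, {'worst radius': 12.0}))
# (0, 0.9111842105263158)
print(first_matching_rule(rules, {'worst radius': 20.0, 'texture error': 1.2,
                                  'worst concavity': 0.5}))
# (3, 0)
```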

In [125]:
# Make predictions on the test set
y_pred = optimal_rule_list_model.predict(X_test)
# Evaluate the model
acc = accuracy_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
print(f'Accuracy: {acc:.2f}')
print(f'Precision: {precision:.2f}')
print(f'Recall: {recall:.2f}')

Accuracy: 0.63
Precision: 0.63
Recall: 1.00


## SLIPPER

In [126]:
slipper_model = SlipperClassifier()
slipper_model.fit(X_train, y_train, feature_names=breast_cancer.feature_names)

In [127]:
# Make predictions on the test set
y_pred = slipper_model.predict(X_test)

# Evaluate the model
acc = accuracy_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
print(f'Accuracy: {acc:.2f}')
print(f'Precision: {precision:.2f}')
print(f'Recall: {recall:.2f}')


Accuracy: 0.63
Precision: 0.63
Recall: 1.00
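Both the CORELS and SLIPPER cells report accuracy 0.63, precision 0.63, and recall 1.00. With benign as the positive class, recall 1.00 together with precision equal to accuracy implies that every test sample was predicted benign: the stratified test set holds 72 benign out of 114 samples, and 72 / 114 ≈ 0.63. A majority-class baseline makes this visible. The sketch below uses scikit-learn's `DummyClassifier` on the same split; matching its scores suggests these two fitted models do no better than always predicting benign on this split.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42, stratify=data.target)

# Ignores the features and always predicts the training majority class (benign = 1)
baseline = DummyClassifier(strategy='most_frequent').fit(X_train, y_train)
y_pred = baseline.predict(X_test)

print(f'Accuracy: {accuracy_score(y_test, y_pred):.2f}')    # Accuracy: 0.63
print(f'Precision: {precision_score(y_test, y_pred):.2f}')  # Precision: 0.63
print(f'Recall: {recall_score(y_test, y_pred):.2f}')        # Recall: 1.00
```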


## Creating graphs to visualize the models' outputs

In [128]:
# Used a combination of the graphviz Python package documentation (https://graphviz.readthedocs.io/)
# and some help from Claude to figure out the correct way to create the graph
def create_rule_graph(model_name, model):
    # Create a new directed graph, laid out top to bottom
    graph = Digraph(comment=f'Decision Rule Visualization for {model_name}')
    graph.attr(rankdir='TB', size='8,8')

    rules = model.rules_

    # Add nodes and edges based on the rules
    for i, rule in enumerate(rules):
        node_id = f"node_{i}"

        # The final entry is the default rule: it has no 'col'/'cutoff', only a value.
        # Draw it as a terminal node so the dashed edge from the previous rule
        # points at a labelled node instead of an implicit, unlabelled one
        if 'col' not in rule:
            graph.node(node_id, f"Default\nValue: {rule['val']:.4f}\nPoints: {rule['num_pts']}")
            break

        graph.node(node_id, f"{rule['col']}\n≤ {rule['cutoff']:.5f}")

        # 'val' is the mean target over all points reaching this rule and 'val_right'
        # the mean over the points it captures, so the remaining (left) points have:
        num_left = rule['num_pts'] - rule['num_pts_right']
        if num_left > 0:
            left_val = (rule['val'] * rule['num_pts']
                        - rule['val_right'] * rule['num_pts_right']) / num_left
        else:
            left_val = float('nan')

        # Add left (False) branch
        left_id = f"leaf_{i}_left"
        graph.node(left_id, f"Value: {left_val:.4f}\nPoints: {num_left}")
        graph.edge(node_id, left_id, label='False')

        # Add right (True) branch
        right_id = f"leaf_{i}_right"
        graph.node(right_id, f"Value: {rule['val_right']:.4f}\nPoints: {rule['num_pts_right']}")
        graph.edge(node_id, right_id, label='True')

        # Connect the False branch to the next rule if it exists
        if i < len(rules) - 1:
            graph.edge(left_id, f"node_{i+1}", style='dashed')

    return graph

In [129]:
# Save a graph's DOT source to '<filename>.dot'.
# I visualized the outputs by copy-pasting the contents of each .dot file into
# https://www.devtoolsdaily.com/graphviz/
def save_graph(dot, filename, format='pdf', render=False):
    dot.save(f'{filename}.dot')
    print(f"DOT file saved as '{filename}.dot'")

    # Optionally also render an image (e.g. '<filename>.pdf');
    # this requires the Graphviz `dot` executable to be installed and on PATH
    if render:
        dot.render(filename, format=format, cleanup=True)

In [130]:
# Try to save the graph
oneR_graph = create_rule_graph('One R', oneR_model)
save_graph(oneR_graph, 'oneR_box_diagram')

DOT file saved as 'oneR_box_diagram.dot'


In [131]:
optimal_rule_list_graph = create_rule_graph('CORELS', optimal_rule_list_model)
save_graph(optimal_rule_list_graph, 'CORELS_diagram')

DOT file saved as 'CORELS_diagram.dot'


In [132]:
dot = Digraph(comment='SLIPPER Model Decision Rules')
dot.attr(rankdir='TB', size='12,12')

rules = slipper_model.rules_

# Round every numeric token in each rule to 3 decimal places for readability
# (very small thresholds lose precision when rounded this way)
formatted_rules = []
for rule in rules:
    tokens = rule.rule.split(' ')
    for j, token in enumerate(tokens):
        try:
            tokens[j] = f'{float(token):.3f}'
        except ValueError:
            continue  # a feature-name word, operator, or 'and': leave as-is
    formatted_rules.append(' '.join(tokens))

print(formatted_rules)

# Followed almost the same visualization code as above to create the graph.
# This is done separately because SLIPPER outputs a decision rule set,
# which is different from the rule lists produced by the other models
for i, rule in enumerate(formatted_rules):
    node_id = f"rule_{i}"
    dot.node(node_id, f"Rule {i+1}\n{rule}", shape='box')

    # Add Yes/No branches
    yes_id = f"yes_{i}"
    no_id = f"no_{i}"
    dot.node(yes_id, "Yes", shape='ellipse')
    dot.node(no_id, "No", shape='ellipse')
    dot.edge(node_id, yes_id, label='True')
    dot.edge(node_id, no_id, label='False')

    # Connect 'No' to the next rule if it's not the last rule
    if i < len(formatted_rules) - 1:
        dot.edge(no_id, f"rule_{i+1}", style='dashed')

save_graph(dot, 'SLIPPER')

['perimeter error < 3.373 and worst texture < 33.188 and worst perimeter < 101.480', 'worst concave points < 0.085', 'area error < 49.044 and worst texture < 33.031 and worst perimeter < 102.460 and worst concave points < 0.178', 'fractal dimension error > 0.001 and worst concave points < 0.291 and mean smoothness > 0.063 and mean fractal dimension < 0.097', 'mean radius < 15.311 and area error < 34.449 and worst concave points < 0.110', 'radius error < 1.041 and worst perimeter < 104.580 and worst smoothness < 0.170 and mean concavity < 0.141', 'area error < 45.944 and worst texture < 31.841 and worst area < 698.800 and mean concave points < 0.064', 'mean radius < 15.359 and radius error < 0.574 and worst texture < 31.772 and worst concave points < 0.109', 'worst perimeter < 91.440', 'mean texture < 23.162 and worst concavity < 0.196']
DOT file saved as 'SLIPPER.dot'
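Unlike a rule list, the SLIPPER rules above are not checked in order. SLIPPER learns its rules by boosting and combines the rules that fire through a weighted vote. The weights are not printed here, so the sketch below only lists *which* rules fire for a sample; it does not reproduce the model's prediction. The helpers `rule_fires`/`firing_rules` and the sample values are hypothetical, and the rounded thresholds may change borderline decisions.

```python
import operator

# Comparison operators that can appear in the printed rules
OPS = {'<': operator.lt, '<=': operator.le, '>': operator.gt, '>=': operator.ge}


def rule_fires(rule, sample):
    """True if every 'feature op threshold' clause joined by ' and ' holds."""
    for clause in rule.split(' and '):
        # Feature names contain spaces, so split the operator and threshold off the right
        feature, op, threshold = clause.rsplit(' ', 2)
        if not OPS[op](sample[feature], float(threshold)):
            return False
    return True


def firing_rules(rules, sample):
    """Indices of all rules in the set that fire for this sample."""
    return [i for i, rule in enumerate(rules) if rule_fires(rule, sample)]


# Three of the rounded SLIPPER rules printed above
rules = [
    'worst concave points < 0.085',
    'worst perimeter < 91.440',
    'mean texture < 23.162 and worst concavity < 0.196',
]
# Hypothetical sample values for illustration
sample = {'worst concave points': 0.05, 'worst perimeter': 120.0,
          'mean texture': 20.0, 'worst concavity': 0.3}

print(firing_rules(rules, sample))  # [0]
```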
