In [29]:
from lore_sa.dataset import TabularDataset

In [30]:
# Name of the label column in adult.csv; reused when loading, training and explaining.
target = "income"

In [31]:
# Load the CSV as a LORE tabular dataset, declaring `target` as the class column.
dataset = TabularDataset.from_csv('adult.csv', class_name=target)
# Drop every row containing a missing value (explicit defaults: rows, any NaN).
dataset.df.dropna(axis=0, how='any', inplace=True)

In [32]:
# Column labels of the loaded frame (DataFrame.keys() is an alias for .columns).
dataset.df.columns

Index(['age', 'workclass', 'fnlwgt', 'education', 'educational-num',
       'marital-status', 'occupation', 'relationship', 'race', 'gender',
       'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
       'income'],
      dtype='object')

In [33]:
# Remove the features this example does not use.
dataset.df.drop(
    columns=["marital-status", "fnlwgt", "educational-num", "occupation", "native-country"],
    inplace=True,
)

In [34]:
# Refresh dataset.descriptor after the row and column removals in the previous cells.
# Presumably it is rebuilt from dataset.df: the descriptor shown later no longer
# lists the dropped columns.
dataset.update_descriptor()

In [35]:
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from lore_sa.bbox import sklearn_classifier_bbox

def train_model(dataset: TabularDataset, target_name: "str | None" = None,
                test_size: float = 0.3, random_state: int = 42):
    """Fit a random-forest pipeline on ``dataset`` and wrap it as a LORE black box.

    Numeric features are standardised; categorical and ordinal features are
    ordinal-encoded. The feature groups come from ``dataset.descriptor``. The model
    is fit on a stratified train split; the held-out split is not used here.

    Args:
        dataset: tabular dataset whose ``df`` and ``descriptor`` drive training.
        target_name: label column. Defaults to the (single) key of
            ``dataset.descriptor['target']``.
        test_size: fraction of rows held out from training.
        random_state: seed used for both the split and the random forest.

    Returns:
        ``sklearn_classifier_bbox.sklearnBBox`` wrapping the fitted pipeline. The
        pipeline expects a 2-D array whose columns are ``dataset.df`` without the
        target column, in ``dataset.df`` order.
    """
    if target_name is None:
        target_name = next(iter(dataset.descriptor['target']))

    X = dataset.df.drop(columns=[target_name])
    y = dataset.df[target_name].values

    def _positions(kind):
        # Column positions inside X (target already removed). The descriptor's own
        # 'index' counts positions in the full frame, which is off by one for any
        # feature located after the target column.
        return [X.columns.get_loc(name)
                for name in dataset.descriptor.get(kind, {})
                if name != target_name]

    numeric_indices = _positions('numeric')
    # Ordinal columns must be listed in some transformer, otherwise the
    # ColumnTransformer default remainder='drop' silently discards them.
    categorical_indices = _positions('categorical') + _positions('ordinal')

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), numeric_indices),
            # LORE's synthetic neighbours may contain categories absent from the
            # train split; map them to -1 instead of raising at predict time.
            ('cat', OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1),
             categorical_indices),
        ]
    )
    model = make_pipeline(
        preprocessor,
        RandomForestClassifier(n_estimators=100, random_state=random_state),
    )

    X_train, _X_test, y_train, _y_test = train_test_split(
        X.values, y, test_size=test_size, random_state=random_state, stratify=y
    )
    model.fit(X_train, y_train)

    return sklearn_classifier_bbox.sklearnBBox(model)

In [36]:
# Train the classifier and keep its LORE black-box wrapper.
bbox = train_model(dataset=dataset)

In [37]:
from lore_sa.lore import TabularGeneticGeneratorLore

# Build the LORE explainer (genetic neighbourhood generator) around the trained
# black box and the dataset it was trained on.
tabularLore = TabularGeneticGeneratorLore(bbox, dataset)

In [38]:
# Flatten the feature names of every descriptor section except 'target', and show
# them next to the section names.
out = [
    feature_name
    for section, features in dataset.descriptor.items()
    if section != 'target'
    for feature_name in features
]
out, dataset.descriptor.keys()

(['age',
  'capital-gain',
  'capital-loss',
  'hours-per-week',
  'workclass',
  'education',
  'relationship',
  'race',
  'gender'],
 dict_keys(['numeric', 'categorical', 'ordinal', 'target']))

In [39]:
# Inspect the target entry: its column index, distinct labels and label counts.
dataset.descriptor["target"]

{'income': {'index': 9,
  'distinct_values': ['>50K', '<=50K'],
  'count': {'>50K': 559, '<=50K': 1697}}}

In [40]:
from lore_sa.webapp import Webapp
# Web visualisation front-end for LORE explanations; launched in a later cell.
webapp = Webapp()

In [None]:
# Explain a single instance: the row at position 1 of the feature frame
# (dataset.df without the target column).
x = dataset.df.drop([target], axis=1).iloc[1]
tabularLore.explain(x)

{'rule': {'premises': [{'attr': 'capital-gain', 'val': 7207.0, 'op': '>'}],
  'consequence': {'attr': 'income', 'val': '>50K', 'op': '='}},
 'counterfactuals': [{'premises': [{'attr': 'capital-gain',
     'val': 7207.0,
     'op': '<='},
    {'attr': 'capital-loss', 'val': 2567.0, 'op': '<='},
    {'attr': 'relationship', 'val': 'Husband', 'op': '!='}],
   'consequence': {'attr': 'income', 'val': '<=50K', 'op': '='}}],
 'fidelity': 1.0,
 'deltas': [[{'att': 'capital-gain', 'op': '<=', 'thr': 7207.0}]],
 'counterfactual_samples': [[67,
   'Federal-gov',
   'Doctorate',
   'Unmarried',
   'Black',
   'Male',
   4766,
   0,
   40],
  [28,
   'Private',
   'Doctorate',
   'Unmarried',
   'Amer-Indian-Eskimo',
   'Male',
   6789,
   0,
   28],
  [28, 'Private', 'Doctorate', 'Unmarried', 'Other', 'Male', 6789, 0, 33],
  [28, 'Private', 'Doctorate', 'Unmarried', 'Other', 'Male', 6789, 0, 33],
  [28,
   'Private',
   'Doctorate',
   'Own-child',
   'Amer-Indian-Eskimo',
   'Female',
   6789,
 

In [None]:
# Launch the interactive explanation web app for the explainer's black box and dataset.
# The captured log shows an API on port 8000 and a client on port 8080. The server
# appears to run until interrupted — TODO confirm whether this call blocks the kernel.
# Alternative entry point, not used here: webapp.launch_demo()
webapp.interactive_explanation(tabularLore.bbox, tabularLore.dataset, target)

INFO:     Started server process [29484]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


Launching LORE_sa explanation viz webapp
Starting API server on 0.0.0.0:8000
INFO:     127.0.0.1:59686 - "GET /api/get-datasets HTTP/1.1" 200 OK
API server is ready at http://localhost:8000
NPM version detected: 10.9.3
Dependencies already installed, skipping npm install
Application started successfully!
API: http://localhost:8000
Client: http://localhost:8080


INFO:     127.0.0.1:59695 - "OPTIONS /api/check-custom-data HTTP/1.1" 200 OK
INFO:     127.0.0.1:59695 - "GET /api/check-custom-data HTTP/1.1" 200 OK
INFO:     127.0.0.1:59697 - "POST /api/explain HTTP/1.1" 200 OK
INFO:     127.0.0.1:59697 - "GET /api/get-classes-colors?method=umap HTTP/1.1" 200 OK
INFO:     127.0.0.1:55873 - "POST /api/explain HTTP/1.1" 200 OK
INFO:     127.0.0.1:55873 - "GET /api/get-classes-colors?method=umap HTTP/1.1" 200 OK


In [43]:
import json
import numpy as np

class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in Python types.

    Used to serialise ``dataset.descriptor``. The descriptor presumably holds NumPy
    numbers, which the stock encoder rejects.
    """

    def default(self, obj):
        # np.bool_ is neither np.integer nor np.floating; without this branch
        # json.dumps(np.bool_(True)) raises TypeError.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Unsupported type: the base implementation raises TypeError.
        return super().default(obj)
    
# Pretty-print the descriptor; print() renders the JSON's newlines instead of a repr.
descriptor_json = json.dumps(dataset.descriptor, cls=NpEncoder, indent=4)
print(descriptor_json)

{
    "numeric": {
        "age": {
            "index": 0,
            "min": 17,
            "max": 90,
            "mean": 38.637411347517734,
            "std": 13.194862308667231,
            "median": 37.0,
            "q1": 28.0,
            "q3": 48.0
        },
        "capital-gain": {
            "index": 6,
            "min": 0,
            "max": 99999,
            "mean": 1326.9773936170213,
            "std": 8955.922794821636,
            "median": 0.0,
            "q1": 0.0,
            "q3": 0.0
        },
        "capital-loss": {
            "index": 7,
            "min": 0,
            "max": 2824,
            "mean": 71.36436170212765,
            "std": 358.22782237692525,
            "median": 0.0,
            "q1": 0.0,
            "q3": 0.0
        },
        "hours-per-week": {
            "index": 8,
            "min": 1,
            "max": 99,
            "mean": 40.77836879432624,
            "std": 12.142706230019007,
            "median": 40.0,
         