In [1]:
from lore_sa.dataset import TabularDataset

In [2]:
# Load the iris CSV, using the "variety" column as the class label.
dataset = TabularDataset.from_csv("iris.csv", class_name="variety")

# Remove rows with missing values directly on the frame held by the dataset.
dataset.df.dropna(inplace=True)

In [3]:
# Show the column labels of the loaded frame (features plus the class column).
dataset.df.columns

Index(['sepal.length', 'sepal.width', 'petal.length', 'petal.width',
       'variety'],
      dtype='object')

In [None]:
# NOTE(review): presumably recomputes `dataset.descriptor` so that it reflects
# the rows remaining after the in-place dropna in the loading cell -- confirm
# against the lore_sa TabularDataset implementation.
dataset.update_descriptor()

In [5]:
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from lore_sa.bbox import sklearn_classifier_bbox


def train_model(dataset: TabularDataset):
    """Fit a random-forest pipeline on ``dataset`` and wrap it as a black box.

    The pipeline standardises the columns listed under
    ``dataset.descriptor['numeric']``, ordinal-encodes those listed under
    ``dataset.descriptor['categorical']``, and feeds the result to a
    ``RandomForestClassifier`` (100 trees, ``random_state=42``).

    The class column is taken from ``dataset.descriptor['target']`` and every
    other column of ``dataset.df`` is used as a feature, so the function is not
    tied to the iris column names.

    Args:
        dataset: Tabular dataset exposing ``df`` (the data frame) and
            ``descriptor`` with ``'numeric'``, ``'categorical'`` and
            ``'target'`` entries.

    Returns:
        ``sklearn_classifier_bbox.sklearnBBox`` wrapping the fitted pipeline.

    Raises:
        ValueError: If the descriptor does not describe exactly one target
            column.
    """
    descriptor = dataset.descriptor
    numeric_indices = [v['index'] for v in descriptor['numeric'].values()]
    categorical_indices = [v['index'] for v in descriptor['categorical'].values()]
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), numeric_indices),
            ('cat', OrdinalEncoder(), categorical_indices),
        ]
    )
    model = make_pipeline(preprocessor, RandomForestClassifier(n_estimators=100, random_state=42))

    # Read the class column name from the descriptor instead of hard-coding it.
    target_columns = list(descriptor['target'])
    if len(target_columns) != 1:
        raise ValueError(
            f"expected exactly one target column in the descriptor, found {target_columns!r}"
        )
    target_name = target_columns[0]

    # NOTE(review): the descriptor indices are positions in `dataset.df`; they
    # line up with the feature matrix below only when the target column comes
    # after every feature column (true for iris.csv) -- confirm for other data.
    features = dataset.df.drop(columns=[target_name]).values
    labels = dataset.df[target_name].values

    # Fit on a stratified 70% split, as before; the held-out 30% is not used
    # here, so it is discarded explicitly.
    X_train, _, y_train, _ = train_test_split(
        features, labels, test_size=0.3, random_state=42, stratify=labels
    )
    model.fit(X_train, y_train)

    return sklearn_classifier_bbox.sklearnBBox(model)

In [6]:
# Train the pipeline on the loaded dataset; `train_model` returns the fitted
# model wrapped in an sklearnBBox, which is what the explainer below receives.
bbox = train_model(dataset)

In [7]:
from lore_sa.lore import TabularGeneticGeneratorLore

# Instantiate TabularGeneticGeneratorLore with the black-box wrapper and the
# dataset it was trained on. In this notebook the object is only referenced by
# the commented-out `interactive_explanation` call in the last cell.
tabularLore = TabularGeneticGeneratorLore(bbox, dataset)

In [8]:
# Gather the feature names from every descriptor group except the target,
# then display them next to the group names.
out = []
for k in dataset.descriptor:
    if k == 'target':
        continue
    # Iterating a descriptor group yields its feature names (the dict keys).
    out.extend(dataset.descriptor[k])
out, dataset.descriptor.keys()

(['sepal.length', 'sepal.width', 'petal.length', 'petal.width'],
 dict_keys(['numeric', 'categorical', 'ordinal', 'target']))

In [9]:
# Inspect the descriptor entry for the class column: per the recorded output it
# holds the column index, the distinct class labels and the per-class counts.
dataset.descriptor["target"]

{'variety': {'index': 4,
  'distinct_values': ['Setosa', 'Versicolor', 'Virginica'],
  'count': {'Setosa': 50, 'Versicolor': 50, 'Virginica': 50}}}

In [10]:
from lore_sa.webapp import Webapp
# Create the web app object. This cell produced no output in the recorded run;
# the server start-up messages appear only under the `launch_demo()` cell.
webapp = Webapp()

In [None]:
# Start the demo web app. In the recorded run this started an API server on
# 0.0.0.0:8000, reported the client at http://localhost:8080 and opened that
# URL in the default browser.
# NOTE(review): the cell has no execution count, which suggests the call blocks
# while the server is running -- confirm; if so, interrupt the kernel to stop it.
# NOTE(review): binding to 0.0.0.0 usually exposes the API on all network
# interfaces -- confirm this is intended outside a trusted network.
webapp.launch_demo()
# Disabled alternative: presumably explains the model and dataset built in this
# notebook rather than the demo -- confirm against lore_sa before enabling.
# webapp.interactive_explanation(tabularLore.bbox, tabularLore.dataset, 'variety')

INFO:     Started server process [15580]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


Launching LORE_sa explanation viz webapp
Starting API server on 0.0.0.0:8000
INFO:     127.0.0.1:49689 - "GET /api/get-datasets HTTP/1.1" 200 OK
API server is ready at http://localhost:8000
Application started successfully!
API: http://localhost:8000
Client: http://localhost:8080
Opening http://localhost:8080 in your default browser...
Browser opened successfully!


INFO:     127.0.0.1:61271 - "GET /api/check-custom-data HTTP/1.1" 200 OK
INFO:     127.0.0.1:61279 - "GET /api/get-datasets HTTP/1.1" 200 OK
INFO:     127.0.0.1:61271 - "GET /api/check-custom-data HTTP/1.1" 200 OK
INFO:     127.0.0.1:61271 - "GET /api/get-datasets HTTP/1.1" 200 OK
INFO:     127.0.0.1:61292 - "GET /api/get-datasets HTTP/1.1" 200 OK
INFO:     127.0.0.1:61271 - "GET /api/check-custom-data HTTP/1.1" 200 OK
INFO:     127.0.0.1:61303 - "GET /api/get-datasets HTTP/1.1" 200 OK
INFO:     127.0.0.1:60010 - "GET /api/check-custom-data HTTP/1.1" 200 OK
INFO:     127.0.0.1:60010 - "GET /api/get-datasets HTTP/1.1" 200 OK
