In [1]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score

# Widen the notebook container to 90% of the browser window.
# `display` lives in `IPython.display`; importing it from `IPython.core.display`
# is deprecated and warns on recent IPython versions.
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

import os
import sys
BASE_PATH = os.path.realpath('..')
DATASETS_DIR = os.path.join(BASE_PATH, 'datasets')
LIB_DIR = os.path.join(BASE_PATH, 'lib')
# `lib` is imported as a package (`from lib import ...`), so its *parent*
# directory must be on sys.path. Use BASE_PATH directly instead of slicing
# "lib" off LIB_DIR (the slice hard-coded len('lib') and left a trailing
# separator, so the membership check never matched BASE_PATH itself).
if BASE_PATH not in sys.path:
    sys.path.append(BASE_PATH)

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

from importlib import reload

In [2]:
from lib import fca_interp as fcai
from lib.utils_ import powerset
from importlib import reload

In [3]:
from sklearn.datasets import load_iris

In [4]:
# Load the bundled iris dataset; later cells read its 'data', 'target',
# 'feature_names' and 'target_names' entries.
iris_data = load_iris()

In [5]:
# One row per flower: the numeric feature columns, the integer label and its name.
iris_ds = pd.DataFrame(iris_data['data'], columns=iris_data['feature_names'])
iris_ds['class_id'] = iris_data['target']
id_to_name = dict(enumerate(iris_data['target_names']))
iris_ds['class'] = iris_ds['class_id'].map(id_to_name)

In [6]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

In [7]:
# Fit a single decision tree on all numeric features to see which ones carry the
# signal (motivates the feature choice `fs` below). The name `rf` is historical:
# it holds a DecisionTreeClassifier. A fixed seed keeps the importances reproducible.
rf = DecisionTreeClassifier(random_state=42)
# `axis` is keyword-only in pandas >= 2.0, so drop label columns via `columns=`.
rf.fit(iris_ds.drop(columns=['class_id', 'class']), iris_ds['class_id'])
rf.feature_importances_

array([0.01333333, 0.01333333, 0.05072262, 0.92261071])

In [8]:
# Features used to build the FCA context: only the two petal measurements
# (presumably chosen from the importances printed above).
fs = ['petal width (cm)','petal length (cm)', ]

In [9]:
# NOTE(review): duplicates the last line of In [7]; safe to remove.
rf.feature_importances_

array([0.01333333, 0.01333333, 0.05072262, 0.92261071])

In [15]:
# Use string row labels; later cells look rows up with `str(g)` for concept extents.
# NOTE(review): mutates iris_ds in place, so re-running earlier cells sees the new index.
iris_ds.index = iris_ds.index.astype(str)

In [16]:
# Many-valued context over the selected features only, with the class ids as
# targets and no categorical attributes.
mvcntx = fcai.MultiValuedContext(iris_ds[fs], y_true=iris_ds['class_id'], cat_attrs=[])

In [17]:
# Build concepts from a random-forest-based algorithm; rf_params are presumably
# forwarded to the forest (min_samples_leaf=10).
# NOTE(review): task_type='' is an empty string — confirm fca_interp accepts it.
fm = fcai.FormalManager(mvcntx, task_type='')
fm.construct_concepts(algo='RandomForest', rf_params={'min_samples_leaf':10})

HBox(children=(FloatProgress(value=0.0, description='Postprocessing', max=101.0, style=ProgressStyle(descripti…




In [23]:
# Link the concepts; only a spanning tree is built (presumably what populates the
# `_low_neighbs_st` / `_up_neighb_st` attributes used below — TODO confirm).
fm.construct_lattice(only_spanning_tree=True)

In [27]:
# Inspect a single concept (id 10). NOTE: `c` is rebound by the loops further down.
c = fm.get_concept_by_id(10)

In [32]:
# Spanning-tree parent id of concept 10, wrapped in a set (the format reused in In [33]).
set([c._up_neighb_st])

{5}

In [159]:
# Attach per-class statistics to every concept: the share of each class within the
# concept's extent (`prob_<k>`), the majority class id, its share and its name.
# Extent members are converted to str because `iris_ds.index` was cast to str above.
# An empty extent yields NaN shares (mean of an empty Series).
class_names = iris_data['target_names']
for concept in fm.get_concepts():
    # Look the extent's labels up once instead of once per class.
    extent_labels = iris_ds.loc[[str(g) for g in concept.get_extent()], 'class_id']
    class_probs = [(extent_labels == cl_id).mean() for cl_id in range(len(class_names))]
    for cl_id, prob in enumerate(class_probs):
        concept._metrics[f'prob_{cl_id}'] = prob
    concept._metrics['class_id'] = np.argmax(class_probs)
    concept._metrics['class_prob'] = max(class_probs)
    concept._metrics['class_name'] = class_names[concept._metrics['class_id']]

In [33]:
# Replace each concept's neighbourhoods with its spanning-tree ones, presumably so
# the plot in the next cell shows only tree edges (TODO confirm get_plotly_fig
# reads `_low_neighbs` / `_up_neighbs`).
for c in fm.get_concepts():
    c._low_neighbs = c._low_neighbs_st
    up_id = c._up_neighb_st
    # A concept without a spanning-tree parent (presumably only the top concept)
    # would otherwise get the bogus neighbour set {None}.
    c._up_neighbs = set() if up_id is None else {up_id}

In [34]:
# Build the plotly figure of the concept structure.
fig = fm.get_plotly_fig()

In [35]:
# Render the figure through its rich repr.
fig

In [51]:
# A single synthetic object: petal width = 2 cm, petal length = 3 cm
# (the tuple order follows `fs`). Used to try out trace_context below.
test_ds = pd.DataFrame([(2,3)], columns=fs)
test_cntx = fcai.MultiValuedContext(test_ds, cat_attrs=[])

In [52]:
def trace_context(self, cntx, root_id=0):
    """Route every object of ``cntx`` down the concept spanning tree of ``self``.

    Starting from concept ``root_id``, concepts are visited in increasing id order
    among the discovered frontier. A lower (spanning-tree) neighbour is only visited
    if at least one object of ``cntx`` falls into its extent.

    Parameters
    ----------
    self : object exposing ``get_concepts()``; each concept provides ``get_id()``,
        ``get_intent()`` and ``_low_neighbs_st`` (ids of its lower neighbours).
    cntx : context exposing ``get_objs()`` and ``get_extent(intent, verb=False)``.
        The extents must contain object positions ``0..len(get_objs())-1``
        (the result dicts are keyed by position; any other value raises KeyError).
    root_id : id of the concept to start from (default ``0``, the value that was
        previously hard-coded).

    Returns
    -------
    (obj_preds_cncpts, obj_all_cncpts) : two dicts keyed by object position.
        ``obj_preds_cncpts[g]`` lists visited concepts whose extent contains ``g``
        while none of their lower neighbours' extents do (where ``g`` stops);
        ``obj_all_cncpts[g]`` lists every visited concept whose extent contains ``g``.
    """
    extent_cache = {}

    def get_extent(c):
        # Memoise each concept's extent in ``cntx``: it is needed both when the
        # concept is visited and when its parent is visited.
        c_id = c.get_id()
        if c_id not in extent_cache:
            extent_cache[c_id] = set(cntx.get_extent(c.get_intent(), verb=False))
        return extent_cache[c_id]

    concepts = self.get_concepts()
    cncpts_dict = {c.get_id(): c for c in concepts}
    n_objs = len(cntx.get_objs())
    obj_preds_cncpts = {idx: [] for idx in range(n_objs)}
    obj_all_cncpts = {idx: [] for idx in range(n_objs)}

    cncpts_to_check = {root_id}
    # Each iteration visits one concept; bounding by the number of concepts
    # guarantees termination even if the tree links contained a cycle.
    for _ in range(len(concepts)):
        if not cncpts_to_check:
            break

        c_id = min(cncpts_to_check)
        cncpts_to_check.remove(c_id)
        c = cncpts_dict[c_id]

        ext = get_extent(c)
        ext_ln = set()
        for ln_id in c._low_neighbs_st:
            ext_ln |= get_extent(cncpts_dict[ln_id])

        # Objects not captured by any lower neighbour stop at this concept.
        for g in ext - ext_ln:
            obj_preds_cncpts[g].append(c_id)
        for g in ext:
            obj_all_cncpts[g].append(c_id)

        # Only descend into lower neighbours that still contain some object.
        cncpts_to_check |= {ln_id for ln_id in c._low_neighbs_st
                            if get_extent(cncpts_dict[ln_id])}

    return obj_preds_cncpts, obj_all_cncpts

In [53]:
# Route the synthetic object through the spanning tree:
# (stop concepts per object, all visited concepts containing each object).
trace_context(fm, test_cntx)

({0: [7]}, {0: [0, 2, 4, 7]})

In [250]:
# Concept metrics to aggregate: the per-class shares attached in In [159].
metric = ['prob_0','prob_1','prob_2']

In [265]:
def predict_context(self, cntx, metric='mean_y_true'):
    """Predict concept metric(s) for every object of ``cntx``.

    Each object is routed down the spanning tree with ``trace_context``; its
    prediction for a metric is the unweighted mean of ``concept._metrics[metric]``
    over the concepts where it stopped. Objects that stopped nowhere get 0.

    Parameters
    ----------
    self : object exposing ``get_concepts()``; each concept provides ``get_id()``
        and a ``_metrics`` dict containing every requested metric name.
    cntx : context passed through to ``trace_context``.
    metric : str or list of str
        Metric name(s) to aggregate; a single name is treated as a one-element list.

    Returns
    -------
    numpy.matrix of shape (n_objects, n_metrics), one column per metric, in order.
    """
    from scipy.sparse import csr_matrix

    obj_preds_cncpts = trace_context(self, cntx)[0]

    metrics = metric if isinstance(metric, list) else [metric]
    # Use ``self`` (not a notebook global) and map concept ids to column positions,
    # so ids need not be contiguous or start at 0.
    concepts = sorted(self.get_concepts(), key=lambda c: c.get_id())
    col_of = {c.get_id(): pos for pos, c in enumerate(concepts)}
    mvals = np.array([[c._metrics[m] for m in metrics] for c in concepts],
                     dtype=float).reshape(len(concepts), len(metrics))

    # W[g, col] = 1 / |stop concepts of g|, so W @ mvals is the per-object mean.
    # W does not depend on the metric, so it is built once for all of them.
    rows, cols, weights = [], [], []
    for g, c_ids in obj_preds_cncpts.items():
        for c_id in c_ids:
            rows.append(g)
            cols.append(col_of[c_id])
            weights.append(1 / len(c_ids))
    W = csr_matrix(
        (np.array(weights, dtype=float),
         (np.array(rows, dtype=int), np.array(cols, dtype=int))),
        shape=(len(obj_preds_cncpts), len(concepts)),
    )
    # Explicitly typed (possibly empty) arrays avoid the IndexError the old
    # ``X_[:, 0]`` slicing raised when no object reached any concept.
    return np.asmatrix(W.dot(mvals))

In [268]:
# Predict per-class shares for the objects of the training context itself.
predict_context(fm, mvcntx, metric)

matrix([[0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0.00526316],
        [0.76315789, 0.23157895, 0

# Production way: the same workflow via the library's `FormalManager.trace_context` / `predict_context`

In [275]:
# Pick up edits to the fca_interp module without restarting the kernel.
# NOTE(review): `%autoreload 2` in the setup cell would make this unnecessary.
fcai = reload(fcai)

In [276]:
# Rebuild the context on all numeric features (label columns dropped) and a fresh
# manager from the reloaded module. `axis` is keyword-only in pandas >= 2.0, so
# the label columns are dropped via `columns=`.
mvcntx = fcai.MultiValuedContext(
    iris_ds.drop(columns=['class_id', 'class']),
    cat_attrs=[],
    y_true=iris_ds['class_id'],
)
fm = fcai.FormalManager(mvcntx)

In [280]:
# Build concepts with a seeded random-forest-based algorithm, then the spanning
# tree (with tqdm progress bars, as seen in the output below).
fm.construct_concepts(algo='RandomForest', rf_params={'random_state':42})
fm.construct_lattice(only_spanning_tree=True, use_tqdm=True,)

HBox(children=(FloatProgress(value=0.0, description='Postprocessing', max=190.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='construct spanning tree', max=190.0, style=ProgressStyle(…




In [282]:
# Library trace_context: first element of the result maps object position to
# its stop-concept ids (see output below).
fm.trace_context(mvcntx)[0]

{0: [19],
 1: [19],
 2: [19],
 3: [19],
 4: [19],
 5: [19],
 6: [19],
 7: [19],
 8: [19],
 9: [19],
 10: [19],
 11: [19],
 12: [19],
 13: [19],
 14: [19],
 15: [19],
 16: [19],
 17: [19],
 18: [19],
 19: [19],
 20: [19],
 21: [19],
 22: [19],
 23: [19],
 24: [19],
 25: [19],
 26: [19],
 27: [19],
 28: [19],
 29: [19],
 30: [19],
 31: [19],
 32: [19],
 33: [19],
 34: [19],
 35: [19],
 36: [19],
 37: [19],
 38: [19],
 39: [19],
 40: [19],
 41: [19],
 42: [19],
 43: [19],
 44: [19],
 45: [19],
 46: [19],
 47: [19],
 48: [19],
 49: [19],
 50: [3, 11, 12, 16, 17, 18, 20, 26, 28, 31, 36, 37, 40, 41, 49],
 51: [3, 11, 12, 16, 17, 18, 20, 26, 28, 31, 36, 37, 40, 41, 49],
 52: [7,
  11,
  12,
  16,
  17,
  18,
  20,
  23,
  26,
  31,
  37,
  41,
  42,
  47,
  70,
  71,
  78,
  89,
  90,
  92,
  94,
  95,
  96,
  109,
  112,
  115,
  116,
  117,
  130,
  132,
  135,
  152,
  175,
  177,
  184],
 53: [3, 11, 12, 16, 17, 18, 20, 26, 28, 31, 36, 37, 40, 41, 49],
 54: [3, 11, 12, 16, 17, 18, 20, 26,

In [283]:
# NOTE(review): identical to In [250]; the per-class shares are (re)attached in the next cell.
metric = ['prob_0','prob_1','prob_2']

In [288]:
# Attach per-class statistics to every concept of the new manager: the share of each
# class within the concept's extent (`prob_<k>`), the majority class id, its share
# and its name. Extent members are converted to str because `iris_ds.index` was
# cast to str above. An empty extent yields NaN shares (mean of an empty Series).
class_names = iris_data['target_names']
for concept in fm.get_concepts():
    # Look the extent's labels up once instead of once per class.
    extent_labels = iris_ds.loc[[str(g) for g in concept.get_extent()], 'class_id']
    class_probs = [(extent_labels == cl_id).mean() for cl_id in range(len(class_names))]
    for cl_id, prob in enumerate(class_probs):
        concept._metrics[f'prob_{cl_id}'] = prob
    concept._metrics['class_id'] = np.argmax(class_probs)
    concept._metrics['class_prob'] = max(class_probs)
    concept._metrics['class_name'] = class_names[concept._metrics['class_id']]

In [289]:
# Library predict_context on the training context; compare with the prototype output of In [268].
fm.predict_context(mvcntx, metric)

matrix([[1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0