In [1]:
import numpy as np
import pandas as pd

from catboost import CatBoostClassifier, Pool, CatBoostRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score

# `display` is taken from the public IPython.display module: importing it
# from IPython.core.display has been deprecated since IPython 7.14.
from IPython.display import display, HTML
# Inject CSS that widens the notebook container to 90% of the page.
display(HTML("<style>.container { width:90% !important; }</style>"))

import os
import sys
# Project root is the parent of the directory the notebook runs from.
BASE_PATH = os.path.realpath('..')
DATASETS_DIR = os.path.join(BASE_PATH, 'datasets')
LIB_DIR = os.path.join(BASE_PATH, 'lib')
# `from lib import ...` needs the directory that CONTAINS lib/ on sys.path.
# Use BASE_PATH directly rather than slicing 'lib' off LIB_DIR, which relied
# on the folder name length and left a trailing separator in the entry.
if BASE_PATH not in sys.path:
    sys.path.append(BASE_PATH)

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

from importlib import reload

In [2]:
# Path of the 'results' directory under the project root.
RESULTS_DIR = os.path.join(BASE_PATH, 'results')

In [3]:
from lib import fca_interp as fcai

In [4]:
from importlib import reload

# Bank dataset

In [5]:
# Attributes that are NOT passed as categorical when contexts are built.
real_feats = [
    'age',
    'default',
    'housing',
    'loan',
    'campaign',
    'pdays',
    'previous',
    'balance',
]
# Attributes cast to strings and passed as `cat_attrs` to the contexts.
cat_feats = [
    'job',
    'marital',
    'education',
    'contact',
    'month',
    'poutcome',
    'day',
]

In [6]:
# Load the bank dataset from the working directory; the first CSV column
# becomes the row index.
bank_ds = pd.read_csv('bank_ds_new.csv', index_col=0)
# Convert row labels to strings so they match the string ids that are later
# read from the test-index text file.
bank_ds.index = bank_ds.index.astype(str)

In [7]:
# Cast every categorical column to str, leaving the other columns untouched.
bank_ds = bank_ds.astype(dict.fromkeys(cat_feats, str))

In [8]:
# Read the comma-separated row labels that make up the test split.
with open('bank_ds_test_indexes.txt', 'r') as f:
    # Strip surrounding whitespace/newlines from each label and drop empty
    # fragments, so a trailing newline or comma in the file cannot yield a
    # label that is missing from the dataset index.
    test_idxs = [label.strip() for label in f.read().split(',') if label.strip()]

In [9]:
# Keep only the rows whose labels are listed in test_idxs.
ds = bank_ds.loc[test_idxs]
print(ds.shape)
# Preview the first rows of the selection.
ds.head()

(1131, 18)


Unnamed: 0,age,job,marital,education,default,balance,housing,loan,contact,day,month,duration,campaign,pdays,previous,poutcome,y,preds
2398,51,entrepreneur,married,secondary,1,-2082,0,1,cellular,28,jul,123,6,,0,unknown,0,0
800,50,management,married,tertiary,0,2881,0,0,cellular,5,aug,510,2,2.0,5,other,0,0
2288,50,technician,married,secondary,0,1412,0,0,cellular,6,aug,131,3,,0,unknown,0,0
2344,37,management,married,tertiary,0,0,1,0,unknown,3,jun,247,13,,0,unknown,0,0
3615,31,admin.,single,secondary,0,757,0,0,cellular,3,feb,343,2,,0,unknown,0,0


In [10]:
# Columns handed to every context, and the predictions of the first 50 rows.
feature_cols = cat_feats + real_feats
head_preds = ds['preds'].head(50)

# Full context: the first 50 rows of ds.
cntx_full = fcai.MultiValuedContext(ds[feature_cols].head(50), cat_attrs=cat_feats)

# Positive context: those of the first 50 rows with preds == 1.
s = head_preds[head_preds == 1].index
cntx_train_pos = fcai.MultiValuedContext(ds.loc[s, feature_cols], cat_attrs=cat_feats)

# Negative context: those of the first 50 rows with preds == 0.
s = head_preds[head_preds == 0].index
cntx_train_neg = fcai.MultiValuedContext(ds.loc[s, feature_cols], cat_attrs=cat_feats)

# Test context: the last 50 rows of ds.
cntx_test = fcai.MultiValuedContext(ds[feature_cols].tail(50), cat_attrs=cat_feats)

In [11]:
# Re-import the library module to pick up source edits made during the session.
# NOTE(review): the contexts built in the previous cell were created from the
# pre-reload module; objects created before a reload keep their old classes,
# so type checks inside the library could fail on them — confirm or rebuild
# the contexts after reloading.
fcai = reload(fcai)

In [12]:
# Manager over the positive context; the 50-row context is supplied as
# `context_full`.
fm = fcai.FormalManager(cntx_train_pos, context_full=cntx_full)

In [13]:
# Build the concepts of the positive context using the library's default
# arguments.
fm.construct_concepts()

HBox(children=(FloatProgress(value=0.0, description='Postprocessing', max=7.0, style=ProgressStyle(description…




In [14]:
# Keep the positive-context concepts; `fm` is rebound to a new manager below.
concepts_pos = fm.get_concepts()

In [15]:
# Rebind `fm` to a manager over the negative context (the positive concepts
# were already saved in concepts_pos).
fm = fcai.FormalManager(cntx_train_neg, context_full=cntx_full)

In [16]:
# Build the concepts of the negative context with the
# 'FromMaxConcepts_Bootstrap' algorithm option.
# NOTE(review): no random seed is set anywhere in view; if the bootstrap
# sampling is random, the resulting concepts may differ between runs — confirm.
fm.construct_concepts(algo='FromMaxConcepts_Bootstrap', n_bootstrap_epochs='2times', sample_size_bootstrap=5)

HBox(children=(FloatProgress(value=0.0, description='construct max strong hyps', max=47.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=0.0, description='agglomerative construction', max=11.0, style=ProgressStyl…

HBox(children=(FloatProgress(value=0.0, description='boostrap aggregating', max=18.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=494.0, style=Progress…




HBox(children=(FloatProgress(value=0.0, max=494.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='pruning', max=16.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='boostrap aggregating', max=6.0, style=ProgressStyle(descr…






HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=147.0, style=Progress…




HBox(children=(FloatProgress(value=0.0, max=147.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='pruning', max=6.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='boostrap aggregating', max=2.0, style=ProgressStyle(descr…






HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=53.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, max=53.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='pruning', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='boostrap aggregating', max=1.0, style=P…




HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=23.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='pruning', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='boostrap aggregating', max=1.0, style=P…




HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=23.0, style=ProgressS…








HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='pruning', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='boostrap aggregating', max=1.0, style=P…




HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=23.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))





HBox(children=(FloatProgress(value=0.0, description='pruning', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='boostrap aggregating', max=1.0, style=P…




HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=23.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='pruning', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='boostrap aggregating', max=1.0, style=P…




HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=23.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))







HBox(children=(FloatProgress(value=0.0, description='pruning', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='boostrap aggregating', max=1.0, style=P…




HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=23.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='pruning', max=2.0, style=ProgressStyle(description_width=…





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='boostrap aggregating', max=1.0, style=P…




HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=23.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='pruning', max=2.0, style=ProgressStyle(description_width=…





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='boostrap aggregating', max=1.0, style=P…




HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=23.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='pruning', max=2.0, style=ProgressStyle(description_width=…






HBox(children=(FloatProgress(value=0.0, description='Postprocessing', max=22.0, style=ProgressStyle(descriptio…




In [17]:
# Keep the negative-context concepts before `fm` is rebound again.
concepts_neg = fm.get_concepts()

In [18]:
# Number of concepts obtained from the positive context.
len(concepts_pos)

7

In [19]:
# Number of concepts obtained from the negative context.
len(concepts_neg)

22

In [20]:
# Manager over the 50-row context that will hold both concept collections.
fm = fcai.FormalManager(cntx_full)
# Union of the positive and negative concepts.
# NOTE(review): this writes the private attribute directly, whereas a later
# cell goes through set_concepts(); confirm the direct assignment is intended.
fm._concepts = concepts_pos|concepts_neg
# Re-derive each concept's extent and intent against the manager's context:
# the extent is looked up from the concept's current intent, then the intent
# is recomputed from that extent, and both overwrite the concept's fields.
# NOTE(review): the concept objects are modified in place, so concepts_pos /
# concepts_neg presumably see the same changes — confirm this is acceptable.
for c in fm.get_concepts():
    ext_ = fm.get_context().get_extent(c.get_intent())
    int_ = fm.get_context().get_intent(ext_)
    c._extent = ext_
    c._intent = int_

In [21]:
# Number the concepts by their position in the order returned by
# sort_concepts().
for idx, c in enumerate(fm.sort_concepts()):
    c._idx = idx

In [22]:
# Attach the mean true label and mean predicted label of each concept's
# objects; concepts with an empty extent get no metrics.
for concept in fm.get_concepts():
    members = concept.get_extent()
    if len(members) == 0:
        continue
    concept._metrics['mean_y_true'] = ds.loc[members, 'y'].mean()
    concept._metrics['mean_y_pred'] = ds.loc[members, 'preds'].mean()

In [23]:
# Predict for the test context with aggfunc='median' and show the first ten
# results; the recorded output contains None for some objects.
fm.predict_context(cntx_test, aggfunc='median')[:10]

[None,
 0.07692307692307693,
 0.07692307692307693,
 0.06976744186046512,
 None,
 0.0,
 0.05846153846153847,
 0.17647058823529413,
 0.07692307692307693,
 None]

# Убираем объекты из существующих понятий

In [24]:
from copy import copy, deepcopy

In [35]:
# List the object labels of the full context (used to pick ids to drop below).
cntx_full.get_objs()

array(['2398', '800', '2288', '2344', '3615', '3548', '1115', '4053',
       '838', '4141', '1189', '1461', '3819', '3614', '179', '4011',
       '4237', '1321', '4018', '2174', '3134', '1878', '1485', '3963',
       '937', '2401', '1876', '415', '1476', '3471', '2179', '2330',
       '2754', '2792', '29', '1222', '199', '3462', '1583', '3541',
       '3707', '3303', '1181', '4077', '3964', '2916', '3214', '1610',
       '2526', '1902'], dtype='<U4')

In [43]:
# Reduced context: the full context minus five objects.
# A deep copy is used because the next call modifies the copy; copy() is
# shallow and would share the context's internal containers with cntx_full,
# which the manager `fm` is still built on.
cntx_red = deepcopy(cntx_full)
cntx_red.drop_objects_from_context(['1461', '1321', '2916', '3214', '2526'])

In [44]:
# Manager over the reduced context.
fm_red = fcai.FormalManager(cntx_red)

In [45]:
# Give the reduced manager the concepts currently held by fm.
# NOTE(review): presumably the same concept objects are passed, not copies —
# if set_concepts modifies them, fm would be affected too; confirm.
fm_red.set_concepts(fm.get_concepts())

In [46]:
# Concept count in the reduced manager (recorded output: 28).
len(fm_red.get_concepts())

28

In [47]:
# Concept count in the original manager, for comparison (recorded output: 29).
len(fm.get_concepts())

29