In [1]:
# Imports and path configuration for the whole notebook.
import os
import sys
from importlib import reload

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

from catboost import CatBoostClassifier, Pool, CatBoostRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score

# Importing `display` from `IPython.core.display` is deprecated (IPython >= 7.14);
# `IPython.display` is the public location.
from IPython.display import display, HTML

# Widen the notebook container so wide tables and progress bars fit.
display(HTML("<style>.container { width:90% !important; }</style>"))

# BASE_PATH is the parent of the current working directory (normally the
# notebook's folder); datasets/ and lib/ live directly under it.
BASE_PATH = os.path.realpath('..')
DATASETS_DIR = os.path.join(BASE_PATH, 'datasets')
LIB_DIR = os.path.join(BASE_PATH, 'lib')

# `lib` is imported as a package (`from lib import ...`), so the directory that
# *contains* it must be on sys.path. Derive it with dirname instead of slicing
# characters off LIB_DIR, which would break if the folder name changed.
LIB_PARENT_DIR = os.path.dirname(LIB_DIR)
if LIB_PARENT_DIR not in sys.path:
    sys.path.append(LIB_PARENT_DIR)

In [2]:
# Results folder under BASE_PATH, alongside datasets/ and lib/ (not referenced in the cells shown).
RESULTS_DIR = os.path.join(BASE_PATH, "results")

In [3]:
from lib import fca_interp as fcai
from lib.utils_ import powerset
from importlib import reload

In [4]:
from copy import copy, deepcopy

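# Load Data

Two datasets are loaded in turn. The second cell (bank) reassigns `ds`, `cat_attrs` and `train_attrs`, so only the bank data reaches the lattice tests below.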

In [5]:
# Toy "mango" dataset: 8 objects with 4 categorical attributes (see shape below).
# NOTE: the next cell reassigns `ds`, `cat_attrs` and `train_attrs`, so this
# data is not what the lattice tests further down operate on.
cat_attrs = ['color', 'firm', 'smooth', 'form']
train_attrs = ['color', 'firm', 'smooth', 'form']
ds = (
    pd.read_csv(os.path.join(DATASETS_DIR, 'mango.csv'))
    .set_index('title')
    # pandas >= 2.0 no longer accepts `axis` positionally in drop().
    .drop(columns='fruit')
    # Cast categorical attributes to str, e.g. boolean columns become 'True'/'False'.
    .astype({attr: str for attr in cat_attrs})
)
print(ds.shape)
ds.head()

(8, 4)


Unnamed: 0_level_0,color,firm,smooth,form
title,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
apple,yellow,False,True,round
grapefruit,yellow,False,False,round
kiwi,green,False,False,oval
plum,blue,False,True,oval
toy cube,green,True,True,cubic


In [6]:
# Bank dataset (semicolon-separated). Only `train_attrs` are used to build the
# formal context; `poutcome` and `y` are not among them.
ds = pd.read_csv(os.path.join(DATASETS_DIR, 'bank', 'bank.csv'), sep=';')
cat_attrs = ['job', 'marital', 'education', 'default', 'housing', 'loan', 'contact', 'month']
train_attrs = [
    'age', 'job', 'marital', 'education', 'default', 'balance', 'housing',
    'loan', 'contact', 'day', 'month', 'duration', 'campaign', 'pdays', 'previous',
]
# Cast the categorical attributes to str in one pass.
ds = ds.astype({attr: str for attr in cat_attrs})
print(ds.shape)
ds.head()

(4521, 17)


Unnamed: 0,age,job,marital,education,default,balance,housing,loan,contact,day,month,duration,campaign,pdays,previous,poutcome,y
0,30,unemployed,married,primary,no,1787,no,no,cellular,19,oct,79,1,-1,0,unknown,no
1,33,services,married,secondary,no,4789,yes,yes,cellular,11,may,220,1,339,4,failure,no
2,35,management,single,tertiary,no,1350,yes,no,cellular,16,apr,185,1,330,1,failure,no
3,30,management,married,tertiary,no,1476,yes,yes,unknown,3,jun,199,4,-1,0,unknown,no
4,59,blue-collar,married,secondary,no,0,yes,no,unknown,5,may,226,1,-1,0,unknown,no


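# Test Lattices

Build the concept lattice twice and check that every concept gets the same upper neighbours. The first pass uses `_construct_lattice_connections`, which produces the reference `lattice_rely`. The second pass goes through the spanning tree and produces `lattice_test`.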

In [7]:
# Re-import the `lib.fca_interp` module so edits to it take effect without a
# kernel restart. reload() returns the module object; rebinding `fcai` covers
# the case where re-import places a different object in sys.modules.
fcai = reload(fcai)

In [8]:
# Only a small prefix of the data is used: the number of formal concepts can
# grow exponentially with the number of objects (13 rows already yield 6107
# concepts, per the progress bars below).
N_OBJECTS = 13
mvcntx = fcai.MultiValuedContext(ds[train_attrs].head(N_OBJECTS), cat_attrs=cat_attrs)

In [9]:
# Wrap the multi-valued context in a FormalManager and construct its concepts
# (use_tqdm=True shows the progress bar in the output below).
fm = fcai.FormalManager(mvcntx)
fm.construct_concepts(use_tqdm=True)

HBox(children=(FloatProgress(value=0.0, description='Postprocessing', max=6107.0, style=ProgressStyle(descript…




In [10]:
%%time
fm._construct_lattice_connections()
lattice_rely = {c.get_id(): c._up_neighbs for c in fm.get_concepts()}

HBox(children=(FloatProgress(value=0.0, description='construct lattice connections', max=6107.0, style=Progres…


CPU times: user 11min 52s, sys: 1.66 s, total: 11min 54s
Wall time: 11min 51s


In [11]:
%%time
fm._construct_spanning_tree()
aun = fm._construct_lattice_from_spanning_tree(use_tqdm=True)
lattice_test = {c.get_id(): c._up_neighbs for c in fm.get_concepts()}

HBox(children=(FloatProgress(value=0.0, description='construct spanning tree', max=6107.0, style=ProgressStyle…




HBox(children=(FloatProgress(value=0.0, description='iterate through chains', max=2899.0, style=ProgressStyle(…


CPU times: user 1min 19s, sys: 353 ms, total: 1min 19s
Wall time: 1min 19s


In [12]:
# Compare the two up-neighbour maps concept by concept (in sorted order).
# The displayed tuple is (number of mismatching concepts, fraction mismatching).
concept_ids = [concept.get_id() for concept in fm.sort_concepts()]
diff_concepts = [lattice_rely[cid] != lattice_test[cid] for cid in concept_ids]
np.sum(diff_concepts), np.mean(diff_concepts)

(0, 0.0)

In [13]:
# Print the first concept (in sorted order) whose up-neighbours differ between
# the two constructions; prints nothing when they all agree.
for c in fm.sort_concepts():
    c_id = c.get_id()
    true_up, estimated_up = lattice_rely[c_id], lattice_test[c_id]
    if true_up != estimated_up:
        report = (
            f'Concept {c_id}\n'
            f'True up neighbs: {true_up}\n'
            f'Estimated up neighbs: {estimated_up}'
        )
        print(report)
        print('=======================================')
        break

In [14]:
# Spot check of one pair of concepts by id.
# NOTE(review): ids 26 and 9 are hard-coded; presumably taken from an earlier
# mismatch report — they are only meaningful for this exact context.
sub_candidate = fm.get_concept_by_id(26)
super_candidate = fm.get_concept_by_id(9)
sub_candidate.is_subconcept_of(super_candidate)

False

In [15]:
# Inspect the chains computed by the manager (private helper); the output
# below is a list of sets of concept ids.
chains = fm._get_chains()
chains

[{0, 2, 35, 128, 417, 993, 2466, 3624, 4653, 5400, 5841, 6081, 6106},
 {0, 1, 34, 194, 544, 1152, 2626, 3976, 5066, 5730, 6008, 6089, 6105},
 {0, 1, 34, 194, 544, 1152, 2626, 3976, 5066, 5720, 5997, 6082, 6104},
 {0, 1, 34, 194, 544, 1508, 2632, 3846, 5061, 5712, 5992, 6080, 6103},
 {0, 1, 34, 194, 544, 1152, 2626, 3976, 5036, 5683, 5971, 6073, 6102},
 {0, 1, 41, 133, 354, 836, 1633, 2891, 4066, 5359, 5936, 6061, 6101},
 {0, 1, 34, 194, 516, 1089, 1921, 3202, 4943, 5544, 5900, 6052, 6100},
 {0, 1, 34, 194, 544, 1152, 2626, 3851, 4864, 5529, 5889, 6049, 6099},
 {0, 1, 34, 194, 544, 1152, 2626, 3976, 5066, 5730, 6008, 6040, 6098},
 {0, 1, 34, 194, 544, 1152, 2626, 3976, 5066, 5730, 6009, 6088, 6097},
 {0, 1, 34, 194, 544, 1152, 2626, 3976, 5066, 5730, 6008, 6040, 6096},
 {0, 1, 34, 233, 609, 1289, 2221, 3437, 4480, 5280, 5800, 6028, 6095},
 {0, 1, 34, 194, 544, 1152, 2051, 3309, 4403, 5260, 5746, 6017, 6094},
 {0, 3, 36, 129, 418, 994, 2467, 3626, 4656, 5417, 5853, 6093},
 {0, 2, 35, 128