# Decision tree classifier evaluation (clustered data)

In [1]:
from os.path import join
import numpy as np
from validation import BinarizedModelValidation
from color_map.color_map_segm_model import ColorMapModel
from ml_models.global_params import fid
from ml_models.one_shot_random_forest import OSRandomForest
from sklearn import tree
from typing import Optional

In [2]:
%matplotlib inline

In [None]:
# Image-resolution tag.
# NOTE(review): nothing in the visible part of this notebook reads IMG_RES;
# the dataset name '224x224_anim10' is spelled out literally where it is used.
IMG_RES = '224x224'

In [3]:
def _print(title, metrics):
    """Print a blank line, the title, then one ``name = value`` row per metric.

    :param title: heading printed above the metric rows.
    :param metrics: mapping from metric name to a numeric value; each value
        is rendered with four decimal places, in the mapping's iteration order.
    """
    rows = ['{} = {:.4f}'.format(name, metrics[name]) for name in metrics]
    # The leading empty string yields the blank separator line before the title.
    print('', title, *rows, sep='\n')

In [4]:
def compute_random_forest(*, n_estimators: int = 10, criterion: str = 'gini', max_depth: Optional[int] = None,
                          dataset_name: str = '224x224_anim10'):
    """Build an ``OSRandomForest`` from a dataset's train split and print its metrics.

    The train features and outputs are loaded from
    ``datasets/<dataset_name>/train/{features,output}.npy`` and handed to
    ``OSRandomForest``. The resulting model is then passed to
    ``BinarizedModelValidation`` for each of the ``train``, ``test`` and
    ``train_test`` sub-directories, and the metrics are printed via ``_print``.

    :param n_estimators: forwarded to ``OSRandomForest`` as ``n_estimators``.
    :param criterion: forwarded to ``OSRandomForest`` as ``criterion``; also
        part of the model name.
    :param max_depth: forwarded to ``OSRandomForest`` as ``max_depth``; also
        part of the model name.
    :param dataset_name: name of the directory under ``datasets/`` to train
        and validate on. Defaults to the dataset this notebook was written for.
    :return: the constructed ``OSRandomForest`` instance.
    """
    train_dir = join('datasets', dataset_name, 'train')
    train_X = np.load(join(train_dir, 'features.npy'))
    train_Y = np.load(join(train_dir, 'output.npy'))
    # NOTE(review): the model name does not include n_estimators, so forests
    # that differ only in tree count share a name. Confirm whether
    # OSRandomForest uses the name as a cache/storage key before relying on it.
    forest = OSRandomForest(
        f'forest_{criterion}_maxdepth_{max_depth}',
        train_x=train_X,
        train_y=train_Y,
        n_estimators=n_estimators,
        criterion=criterion,
        max_depth=max_depth
    )
    for variant in ('train', 'test', 'train_test'):
        validation = BinarizedModelValidation(
            dataset_path=f'datasets/{dataset_name}/{variant}',
            segmentation_model=forest
        )
        _print(f'{variant} | {dataset_name}', validation.get_metrics())
    return forest

## Simple Dtree (small random forest: 5 estimators, max depth 5)

In [None]:
# Small configuration: 5 estimators, depth capped at 5, default 'gini' criterion.
# The call prints metrics for the train, test and train_test variants.
forest_5_5 = compute_random_forest(n_estimators=5, max_depth=5)

100%|██████████| 21/21 [03:45<00:00, 10.72s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.3128
Mult IOU std = 0.3952
Bin IOU mean = 0.2016
Bin IOU std = 0.3780
Bin miss-rate mean = 0.4180
Bin miss-rate std = 0.4895
Red miss-rate mean = 0.2562
Red miss-rate std = 0.4365
Orange miss-rate mean = 0.1240
Orange miss-rate std = 0.3295
Yellow miss-rate mean = 0.3916
Yellow miss-rate std = 0.4872


100%|██████████| 15/15 [02:38<00:00, 10.58s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0420
Mult IOU std = 0.1199
Bin IOU mean = 0.0039
Bin IOU std = 0.0137
Bin miss-rate mean = 0.6990
Bin miss-rate std = 0.4507
Red miss-rate mean = 0.6744
Red miss-rate std = 0.4686
Orange miss-rate mean = 0.2035
Orange miss-rate std = 0.4026
Yellow miss-rate mean = 0.6683
Yellow miss-rate std = 0.4705


  6%|▌         | 2/36 [00:43<10:43, 18.92s/it]