# Evaluation of different ML models

In [1]:
from os.path import join
import numpy as np
from validation import BinarizedModelValidation
from color_map.color_map_segm_model import ColorMapModel
from ml_models.global_params import fid
from ml_models.one_shot_knn import OneShotKNN
# from ml_models.two_shot_knn import TwoShotKNN

In [2]:
def _print(title, metrics):
    print()
    print(title)
    for key in metrics:
        print('{} = {:.4f}'.format(key, metrics[key]))

In [3]:
def compute_oneshot_knn(nn: int, weights: str, variants=('train', 'test', 'train_test')):
    """Evaluate a one-shot KNN segmentation model on dataset splits and print the metrics.

    A ``OneShotKNN`` model is built from the pre-computed ``cluster_train``
    features/labels of the ``224x224_anim10`` dataset. It is then validated with
    ``BinarizedModelValidation`` on every split listed in ``variants``. The metrics
    of each split are printed via ``_print``.

    Args:
        nn: Number of neighbours, forwarded to ``OneShotKNN`` as ``num_neighb``.
        weights: Neighbour weighting scheme, forwarded as ``weights_type``
            (this notebook uses ``'uniform'`` and ``'distance'``).
        variants: Sub-directories of ``datasets/224x224_anim10`` to evaluate.
            Defaults to the ``train``, ``test`` and combined ``train_test`` splits.

    Returns:
        None; results are only printed.
    """
    # The training data is the same for every evaluated split. Load it once
    # instead of re-reading both .npy files from disk on each iteration.
    # NOTE(review): assumes OneShotKNN does not modify train_x / train_y in place,
    # since the same arrays are now shared across splits — confirm.
    cluster_dir = join('datasets', '224x224_anim10', 'cluster_train')
    train_X = np.load(join(cluster_dir, 'features.npy'))
    train_Y = np.load(join(cluster_dir, 'output.npy'))

    for variant in variants:
        # Build a fresh model per split, exactly as before.
        model = OneShotKNN(
            f'one-shot_knn_clu_224x224[{nn}_{weights}]',
            num_neighb=nn,
            weights_type=weights,
            train_x=train_X,
            train_y=train_Y,
        )
        metrics = BinarizedModelValidation(
            dataset_path=f'datasets/224x224_anim10/{variant}',
            segmentation_model=model,
        ).get_metrics()
        _print(f'{variant} | 224x224_anim10', metrics)

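### One-shot KNN classifier

The classifier is built from the `cluster_train` features of `224x224_anim10`. It is evaluated on the `train`, `test` and combined `train_test` splits for each number of neighbours (NN), with both `uniform` and `distance` weighting.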

#### 1. NN=3

In [4]:
compute_oneshot_knn(nn=3, weights='uniform')

100%|██████████| 21/21 [13:29<00:00, 38.53s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1314
Mult IOU std = 0.1607
Bin IOU mean = 0.0265
Bin IOU std = 0.0425
Bin miss-rate mean = 0.9730
Bin miss-rate std = 0.0435
Red miss-rate mean = 0.8987
Red miss-rate std = 0.1544
Orange miss-rate mean = 0.2645
Orange miss-rate std = 0.4410
Yellow miss-rate mean = 0.9925
Yellow miss-rate std = 0.0116


100%|██████████| 15/15 [09:48<00:00, 39.24s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0359
Mult IOU std = 0.0604
Bin IOU mean = 0.0824
Bin IOU std = 0.0784
Bin miss-rate mean = 0.9140
Bin miss-rate std = 0.0850
Red miss-rate mean = 0.8427
Red miss-rate std = 0.1785
Orange miss-rate mean = 0.4190
Orange miss-rate std = 0.4876
Yellow miss-rate mean = 0.9772
Yellow miss-rate std = 0.0272


100%|██████████| 36/36 [23:08<00:00, 38.56s/it]  


train_test | 224x224_anim10
Mult IOU mean = 0.0917
Mult IOU std = 0.1372
Bin IOU mean = 0.0498
Bin IOU std = 0.0661
Bin miss-rate mean = 0.9484
Bin miss-rate std = 0.0704
Red miss-rate mean = 0.8754
Red miss-rate std = 0.1671
Orange miss-rate mean = 0.3287
Orange miss-rate std = 0.4672
Yellow miss-rate mean = 0.9862
Yellow miss-rate std = 0.0210





In [5]:
compute_oneshot_knn(nn=3, weights='distance')

100%|██████████| 21/21 [09:11<00:00, 26.24s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1308
Mult IOU std = 0.1597
Bin IOU mean = 0.0267
Bin IOU std = 0.0427
Bin miss-rate mean = 0.9729
Bin miss-rate std = 0.0436
Red miss-rate mean = 0.9003
Red miss-rate std = 0.1524
Orange miss-rate mean = 0.2727
Orange miss-rate std = 0.4454
Yellow miss-rate mean = 0.9925
Yellow miss-rate std = 0.0116


100%|██████████| 15/15 [06:38<00:00, 26.57s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0334
Mult IOU std = 0.0542
Bin IOU mean = 0.0829
Bin IOU std = 0.0791
Bin miss-rate mean = 0.9136
Bin miss-rate std = 0.0855
Red miss-rate mean = 0.8444
Red miss-rate std = 0.1770
Orange miss-rate mean = 0.4598
Orange miss-rate std = 0.4925
Yellow miss-rate mean = 0.9772
Yellow miss-rate std = 0.0272


100%|██████████| 36/36 [15:49<00:00, 26.38s/it]


train_test | 224x224_anim10
Mult IOU mean = 0.0903
Mult IOU std = 0.1358
Bin IOU mean = 0.0500
Bin IOU std = 0.0666
Bin miss-rate mean = 0.9482
Bin miss-rate std = 0.0707
Red miss-rate mean = 0.8771
Red miss-rate std = 0.1654
Orange miss-rate mean = 0.3505
Orange miss-rate std = 0.4746
Yellow miss-rate mean = 0.9861
Yellow miss-rate std = 0.0210





#### 2. NN=5

In [6]:
compute_oneshot_knn(nn=5, weights='uniform')

100%|██████████| 21/21 [13:43<00:00, 39.20s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1291
Mult IOU std = 0.1553
Bin IOU mean = 0.0265
Bin IOU std = 0.0423
Bin miss-rate mean = 0.9729
Bin miss-rate std = 0.0435
Red miss-rate mean = 0.8973
Red miss-rate std = 0.1548
Orange miss-rate mean = 0.2521
Orange miss-rate std = 0.4342
Yellow miss-rate mean = 0.9926
Yellow miss-rate std = 0.0115


100%|██████████| 15/15 [09:56<00:00, 39.78s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0347
Mult IOU std = 0.0554
Bin IOU mean = 0.0831
Bin IOU std = 0.0797
Bin miss-rate mean = 0.9131
Bin miss-rate std = 0.0867
Red miss-rate mean = 0.8351
Red miss-rate std = 0.1823
Orange miss-rate mean = 0.4733
Orange miss-rate std = 0.4920
Yellow miss-rate mean = 0.9775
Yellow miss-rate std = 0.0266


100%|██████████| 36/36 [23:40<00:00, 39.45s/it]  


train_test | 224x224_anim10
Mult IOU mean = 0.0899
Mult IOU std = 0.1324
Bin IOU mean = 0.0500
Bin IOU std = 0.0668
Bin miss-rate mean = 0.9480
Bin miss-rate std = 0.0714
Red miss-rate mean = 0.8714
Red miss-rate std = 0.1696
Orange miss-rate mean = 0.3440
Orange miss-rate std = 0.4719
Yellow miss-rate mean = 0.9863
Yellow miss-rate std = 0.0206





In [7]:
compute_oneshot_knn(nn=5, weights='distance')

100%|██████████| 21/21 [09:33<00:00, 27.29s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1331
Mult IOU std = 0.1642
Bin IOU mean = 0.0267
Bin IOU std = 0.0426
Bin miss-rate mean = 0.9728
Bin miss-rate std = 0.0436
Red miss-rate mean = 0.8933
Red miss-rate std = 0.1673
Orange miss-rate mean = 0.2521
Orange miss-rate std = 0.4342
Yellow miss-rate mean = 0.9925
Yellow miss-rate std = 0.0116


100%|██████████| 15/15 [07:00<00:00, 28.05s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0324
Mult IOU std = 0.0529
Bin IOU mean = 0.0832
Bin IOU std = 0.0797
Bin miss-rate mean = 0.9133
Bin miss-rate std = 0.0860
Red miss-rate mean = 0.8430
Red miss-rate std = 0.1810
Orange miss-rate mean = 0.4715
Orange miss-rate std = 0.4934
Yellow miss-rate mean = 0.9771
Yellow miss-rate std = 0.0271


100%|██████████| 36/36 [16:31<00:00, 27.54s/it]


train_test | 224x224_anim10
Mult IOU mean = 0.0913
Mult IOU std = 0.1392
Bin IOU mean = 0.0502
Bin IOU std = 0.0669
Bin miss-rate mean = 0.9481
Bin miss-rate std = 0.0711
Red miss-rate mean = 0.8724
Red miss-rate std = 0.1749
Orange miss-rate mean = 0.3432
Orange miss-rate std = 0.4722
Yellow miss-rate mean = 0.9861
Yellow miss-rate std = 0.0210





#### 3. NN=7

In [8]:
compute_oneshot_knn(nn=7, weights='uniform')

100%|██████████| 21/21 [13:57<00:00, 39.90s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1350
Mult IOU std = 0.1684
Bin IOU mean = 0.0270
Bin IOU std = 0.0432
Bin miss-rate mean = 0.9724
Bin miss-rate std = 0.0444
Red miss-rate mean = 0.8940
Red miss-rate std = 0.1738
Orange miss-rate mean = 0.2645
Orange miss-rate std = 0.4410
Yellow miss-rate mean = 0.9922
Yellow miss-rate std = 0.0121


100%|██████████| 15/15 [10:09<00:00, 40.64s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0340
Mult IOU std = 0.0596
Bin IOU mean = 0.0844
Bin IOU std = 0.0802
Bin miss-rate mean = 0.9118
Bin miss-rate std = 0.0871
Red miss-rate mean = 0.8481
Red miss-rate std = 0.1803
Orange miss-rate mean = 0.4669
Orange miss-rate std = 0.4922
Yellow miss-rate mean = 0.9759
Yellow miss-rate std = 0.0293


100%|██████████| 36/36 [24:11<00:00, 40.32s/it]  


train_test | 224x224_anim10
Mult IOU mean = 0.0930
Mult IOU std = 0.1433
Bin IOU mean = 0.0508
Bin IOU std = 0.0676
Bin miss-rate mean = 0.9472
Bin miss-rate std = 0.0721
Red miss-rate mean = 0.8749
Red miss-rate std = 0.1779
Orange miss-rate mean = 0.3486
Orange miss-rate std = 0.4736
Yellow miss-rate mean = 0.9854
Yellow miss-rate std = 0.0225





In [9]:
compute_oneshot_knn(nn=7, weights='distance')

100%|██████████| 21/21 [09:52<00:00, 28.19s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1339
Mult IOU std = 0.1665
Bin IOU mean = 0.0268
Bin IOU std = 0.0429
Bin miss-rate mean = 0.9726
Bin miss-rate std = 0.0440
Red miss-rate mean = 0.8940
Red miss-rate std = 0.1739
Orange miss-rate mean = 0.2645
Orange miss-rate std = 0.4410
Yellow miss-rate mean = 0.9923
Yellow miss-rate std = 0.0120


100%|██████████| 15/15 [07:13<00:00, 28.91s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0332
Mult IOU std = 0.0583
Bin IOU mean = 0.0841
Bin IOU std = 0.0804
Bin miss-rate mean = 0.9122
Bin miss-rate std = 0.0870
Red miss-rate mean = 0.8493
Red miss-rate std = 0.1773
Orange miss-rate mean = 0.4654
Orange miss-rate std = 0.4932
Yellow miss-rate mean = 0.9761
Yellow miss-rate std = 0.0291


100%|██████████| 36/36 [17:07<00:00, 28.54s/it]


train_test | 224x224_anim10
Mult IOU mean = 0.0921
Mult IOU std = 0.1417
Bin IOU mean = 0.0506
Bin IOU std = 0.0675
Bin miss-rate mean = 0.9475
Bin miss-rate std = 0.0719
Red miss-rate mean = 0.8755
Red miss-rate std = 0.1767
Orange miss-rate mean = 0.3479
Orange miss-rate std = 0.4739
Yellow miss-rate mean = 0.9856
Yellow miss-rate std = 0.0224





#### 4. NN=9

In [10]:
compute_oneshot_knn(nn=9, weights='uniform')

100%|██████████| 21/21 [14:15<00:00, 40.74s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1301
Mult IOU std = 0.1565
Bin IOU mean = 0.0270
Bin IOU std = 0.0432
Bin miss-rate mean = 0.9724
Bin miss-rate std = 0.0445
Red miss-rate mean = 0.8892
Red miss-rate std = 0.1644
Orange miss-rate mean = 0.2603
Orange miss-rate std = 0.4388
Yellow miss-rate mean = 0.9924
Yellow miss-rate std = 0.0120


100%|██████████| 15/15 [10:23<00:00, 41.59s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0378
Mult IOU std = 0.0655
Bin IOU mean = 0.0848
Bin IOU std = 0.0804
Bin miss-rate mean = 0.9112
Bin miss-rate std = 0.0874
Red miss-rate mean = 0.8206
Red miss-rate std = 0.1948
Orange miss-rate mean = 0.4499
Orange miss-rate std = 0.4903
Yellow miss-rate mean = 0.9755
Yellow miss-rate std = 0.0307


100%|██████████| 36/36 [24:37<00:00, 41.03s/it]  


train_test | 224x224_anim10
Mult IOU mean = 0.0917
Mult IOU std = 0.1348
Bin IOU mean = 0.0510
Bin IOU std = 0.0677
Bin miss-rate mean = 0.9470
Bin miss-rate std = 0.0724
Red miss-rate mean = 0.8607
Red miss-rate std = 0.1808
Orange miss-rate mean = 0.3391
Orange miss-rate std = 0.4703
Yellow miss-rate mean = 0.9854
Yellow miss-rate std = 0.0233





In [11]:
compute_oneshot_knn(nn=9, weights='distance')

100%|██████████| 21/21 [10:07<00:00, 28.91s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1352
Mult IOU std = 0.1684
Bin IOU mean = 0.0270
Bin IOU std = 0.0432
Bin miss-rate mean = 0.9725
Bin miss-rate std = 0.0443
Red miss-rate mean = 0.8878
Red miss-rate std = 0.1785
Orange miss-rate mean = 0.2603
Orange miss-rate std = 0.4388
Yellow miss-rate mean = 0.9923
Yellow miss-rate std = 0.0120


100%|██████████| 15/15 [07:25<00:00, 29.67s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0348
Mult IOU std = 0.0590
Bin IOU mean = 0.0842
Bin IOU std = 0.0802
Bin miss-rate mean = 0.9121
Bin miss-rate std = 0.0868
Red miss-rate mean = 0.8359
Red miss-rate std = 0.1882
Orange miss-rate mean = 0.4490
Orange miss-rate std = 0.4910
Yellow miss-rate mean = 0.9754
Yellow miss-rate std = 0.0309


100%|██████████| 36/36 [17:31<00:00, 29.20s/it]  


train_test | 224x224_anim10
Mult IOU mean = 0.0935
Mult IOU std = 0.1430
Bin IOU mean = 0.0508
Bin IOU std = 0.0675
Bin miss-rate mean = 0.9474
Bin miss-rate std = 0.0719
Red miss-rate mean = 0.8662
Red miss-rate std = 0.1844
Orange miss-rate mean = 0.3387
Orange miss-rate std = 0.4705
Yellow miss-rate mean = 0.9853
Yellow miss-rate std = 0.0234





#### 5. NN=11

In [12]:
compute_oneshot_knn(nn=11, weights='uniform')

100%|██████████| 21/21 [14:37<00:00, 41.77s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1365
Mult IOU std = 0.1706
Bin IOU mean = 0.0271
Bin IOU std = 0.0433
Bin miss-rate mean = 0.9723
Bin miss-rate std = 0.0445
Red miss-rate mean = 0.8718
Red miss-rate std = 0.2041
Orange miss-rate mean = 0.2686
Orange miss-rate std = 0.4432
Yellow miss-rate mean = 0.9922
Yellow miss-rate std = 0.0123


100%|██████████| 15/15 [10:36<00:00, 42.43s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0375
Mult IOU std = 0.0661
Bin IOU mean = 0.0844
Bin IOU std = 0.0800
Bin miss-rate mean = 0.9117
Bin miss-rate std = 0.0870
Red miss-rate mean = 0.8172
Red miss-rate std = 0.2078
Orange miss-rate mean = 0.4510
Orange miss-rate std = 0.4896
Yellow miss-rate mean = 0.9737
Yellow miss-rate std = 0.0354


100%|██████████| 36/36 [25:16<00:00, 42.13s/it]  


train_test | 224x224_anim10
Mult IOU mean = 0.0954
Mult IOU std = 0.1456
Bin IOU mean = 0.0509
Bin IOU std = 0.0675
Bin miss-rate mean = 0.9471
Bin miss-rate std = 0.0721
Red miss-rate mean = 0.8491
Red miss-rate std = 0.2074
Orange miss-rate mean = 0.3444
Orange miss-rate std = 0.4717
Yellow miss-rate mean = 0.9845
Yellow miss-rate std = 0.0263





In [13]:
compute_oneshot_knn(nn=11, weights='distance')

100%|██████████| 21/21 [10:27<00:00, 29.90s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1365
Mult IOU std = 0.1718
Bin IOU mean = 0.0268
Bin IOU std = 0.0429
Bin miss-rate mean = 0.9726
Bin miss-rate std = 0.0440
Red miss-rate mean = 0.8769
Red miss-rate std = 0.1981
Orange miss-rate mean = 0.2603
Orange miss-rate std = 0.4388
Yellow miss-rate mean = 0.9922
Yellow miss-rate std = 0.0122


100%|██████████| 15/15 [07:40<00:00, 30.68s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0354
Mult IOU std = 0.0626
Bin IOU mean = 0.0837
Bin IOU std = 0.0797
Bin miss-rate mean = 0.9126
Bin miss-rate std = 0.0864
Red miss-rate mean = 0.8289
Red miss-rate std = 0.1960
Orange miss-rate mean = 0.4552
Orange miss-rate std = 0.4913
Yellow miss-rate mean = 0.9738
Yellow miss-rate std = 0.0354


100%|██████████| 36/36 [18:07<00:00, 30.22s/it]  


train_test | 224x224_anim10
Mult IOU mean = 0.0945
Mult IOU std = 0.1461
Bin IOU mean = 0.0505
Bin IOU std = 0.0671
Bin miss-rate mean = 0.9477
Bin miss-rate std = 0.0715
Red miss-rate mean = 0.8570
Red miss-rate std = 0.1986
Orange miss-rate mean = 0.3413
Orange miss-rate std = 0.4712
Yellow miss-rate mean = 0.9845
Yellow miss-rate std = 0.0263





#### 6. NN=13

In [14]:
compute_oneshot_knn(nn=13, weights='uniform')

100%|██████████| 21/21 [14:50<00:00, 42.41s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1409
Mult IOU std = 0.1817
Bin IOU mean = 0.0271
Bin IOU std = 0.0435
Bin miss-rate mean = 0.9723
Bin miss-rate std = 0.0447
Red miss-rate mean = 0.8552
Red miss-rate std = 0.2320
Orange miss-rate mean = 0.2562
Orange miss-rate std = 0.4365
Yellow miss-rate mean = 0.9920
Yellow miss-rate std = 0.0126


100%|██████████| 15/15 [10:47<00:00, 43.17s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0374
Mult IOU std = 0.0668
Bin IOU mean = 0.0846
Bin IOU std = 0.0800
Bin miss-rate mean = 0.9114
Bin miss-rate std = 0.0870
Red miss-rate mean = 0.8124
Red miss-rate std = 0.2114
Orange miss-rate mean = 0.4325
Orange miss-rate std = 0.4881
Yellow miss-rate mean = 0.9727
Yellow miss-rate std = 0.0377


100%|██████████| 36/36 [25:45<00:00, 42.92s/it]  


train_test | 224x224_anim10
Mult IOU mean = 0.0979
Mult IOU std = 0.1541
Bin IOU mean = 0.0510
Bin IOU std = 0.0676
Bin miss-rate mean = 0.9470
Bin miss-rate std = 0.0722
Red miss-rate mean = 0.8375
Red miss-rate std = 0.2247
Orange miss-rate mean = 0.3295
Orange miss-rate std = 0.4668
Yellow miss-rate mean = 0.9840
Yellow miss-rate std = 0.0278





In [15]:
compute_oneshot_knn(nn=13, weights='distance')

100%|██████████| 21/21 [10:43<00:00, 30.65s/it]
  0%|          | 0/15 [00:00<?, ?it/s]


train | 224x224_anim10
Mult IOU mean = 0.1403
Mult IOU std = 0.1807
Bin IOU mean = 0.0268
Bin IOU std = 0.0430
Bin miss-rate mean = 0.9726
Bin miss-rate std = 0.0441
Red miss-rate mean = 0.8702
Red miss-rate std = 0.2079
Orange miss-rate mean = 0.2686
Orange miss-rate std = 0.4432
Yellow miss-rate mean = 0.9921
Yellow miss-rate std = 0.0124


100%|██████████| 15/15 [07:52<00:00, 31.48s/it]
  0%|          | 0/36 [00:00<?, ?it/s]


test | 224x224_anim10
Mult IOU mean = 0.0353
Mult IOU std = 0.0628
Bin IOU mean = 0.0837
Bin IOU std = 0.0797
Bin miss-rate mean = 0.9125
Bin miss-rate std = 0.0865
Red miss-rate mean = 0.8199
Red miss-rate std = 0.2027
Orange miss-rate mean = 0.4500
Orange miss-rate std = 0.4902
Yellow miss-rate mean = 0.9729
Yellow miss-rate std = 0.0376


100%|██████████| 36/36 [18:35<00:00, 30.98s/it]  


train_test | 224x224_anim10
Mult IOU mean = 0.0967
Mult IOU std = 0.1530
Bin IOU mean = 0.0505
Bin IOU std = 0.0671
Bin miss-rate mean = 0.9476
Bin miss-rate std = 0.0715
Red miss-rate mean = 0.8493
Red miss-rate std = 0.2072
Orange miss-rate mean = 0.3439
Orange miss-rate std = 0.4719
Yellow miss-rate mean = 0.9841
Yellow miss-rate std = 0.0277



