# Global model explanations

### Zuzanna Glinka
### Karol Degórski
### Adrian Kamiński

## The chosen problem

In this project we focus primarily on analyzing explanations related to shape.

The figure below illustrates the problem we set up. Our goal is to detect whether a large figure (a square, a triangle or a circle) is composed exclusively of small figures of the same kind. This makes it a 4-class classification problem:
- 0: circle
- 1: invalid (labelled `None` in the code)
- 2: square
- 3: triangle

Examples of the images we used:

![](https://raw.githubusercontent.com/adrian22311/2022L-WB-XIC/main/projects/Glinka_Degórski_Kamiński/images_example.png)

We trained several models (resnet18, resnet34, vgg11, efficientnet_b0), which reached an accuracy of 0.88-0.92.

We explain the Resnet18 model.

Because we generated the data ourselves, we had full information on how each image was created (position, size, shape and color of every figure). This allowed us to produce files such as the one below:

In [214]:
pd.read_csv('../data/details/test/circle/000000_circleMulti1.csv')

| Unnamed: 0 | x | y | shape | size | color | label |
|---|---|---|---|---|---|---|
| 0 | 0.823762 | 0.555742 | circle | 0.066667 | yellow | true circle |
| 1 | 0.813852 | 0.631955 | circle | 0.05 | red | true circle |
| 2 | 0.725576 | 0.704774 | circle | 0.066667 | red | true circle |
| 3 | 0.600747 | 0.763981 | circle | 0.075 | yellow | true circle |
| 4 | 0.509717 | 0.778547 | circle | 0.05 | red | true circle |
| 5 | 0.43088 | 0.757554 | circle | 0.083333 | red | true circle |
| 6 | 0.368483 | 0.677482 | circle | 0.05 | blue | true circle |
| 7 | 0.351999 | 0.556972 | circle | 0.041667 | red | true circle |
| 8 | 0.353852 | 0.417086 | circle | 0.083333 | red | true circle |
| 9 | 0.378473 | 0.361297 | circle | 0.05 | red | true circle |
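
The notebook does not show how these detail files were generated. Below is a minimal sketch of the idea, assuming the small figures of a "true circle" are spread evenly along a large circle and that all coordinates and sizes are normalized to [0, 1]. `circle_details`, the radius, the seed and the random size/color choices are hypothetical; the size values are ones that appear in the table above. The round trip at the end shows where the `Unnamed: 0` column comes from.

```python
import io
import numpy as np
import pandas as pd

def circle_details(n, radius=0.25, center=(0.5, 0.5), seed=0):
    """Hypothetical generator: n small circles placed evenly on a big circle."""
    rng = np.random.default_rng(seed)
    angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
    return pd.DataFrame({
        # normalized centre of every small figure
        "x": center[0] + radius * np.cos(angles),
        "y": center[1] + radius * np.sin(angles),
        "shape": "circle",
        # sizes and colors drawn from small hypothetical sets of values
        "size": rng.choice([1 / 24, 1 / 20, 1 / 15, 0.075, 1 / 12], size=n),
        "color": rng.choice(["red", "yellow", "blue"], size=n),
        "label": "true circle",
    })

details = circle_details(10)
# to_csv writes the index as an unnamed first column, which read_csv
# later reports as "Unnamed: 0"
round_trip = pd.read_csv(io.StringIO(details.to_csv()))
print(round_trip.columns[0])  # Unnamed: 0
```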


Next, for each class we selected 100 observations that the model classified correctly. For these images we used SHAP to compute an explanation for every pixel and, using the table above, assigned to each figure the influence it had on the prediction.
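
The matching step itself is not shown in the notebook. A minimal sketch, assuming a 2D SHAP map for the predicted class (for example per-pixel values summed over color channels), figure centres `x`, `y` normalized to [0, 1] with the origin in the top-left corner, and `size` interpreted as the half-width of a square window around each figure. The name `figure_importance`, the square mask and the plain summation are hypothetical; the authors may have used a different mask or normalization.

```python
import numpy as np
import pandas as pd

def figure_importance(shap_map, details):
    """Hypothetical helper: sum per-pixel SHAP values inside each small figure.

    shap_map: 2D array (H, W) of attributions for the predicted class.
    details:  DataFrame with normalized centres x, y (origin top-left)
              and size, assumed here to be the half-width relative to W.
    """
    h, w = shap_map.shape
    rows, cols = np.mgrid[0:h, 0:w]  # pixel coordinates
    importances = []
    for fig in details.itertuples(index=False):
        cx, cy, half = fig.x * w, fig.y * h, fig.size * w
        # square window around the figure centre (assumed mask shape)
        mask = (np.abs(cols - cx) <= half) & (np.abs(rows - cy) <= half)
        importances.append(float(shap_map[mask].sum()))
    return details.assign(Importance=importances)

shap_map = np.zeros((10, 10))
shap_map[4:6, 4:6] = 1.0  # positive attribution in the image centre
details = pd.DataFrame({"x": [0.5, 0.05], "y": [0.5, 0.05], "size": [0.1, 0.05]})
print(figure_importance(shap_map, details)["Importance"].tolist())  # [4.0, 0.0]
```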

At this point we had CSV files such as the one below:

In [217]:
pd.read_csv('./csv/_0_0_importance.csv').head()

| Unnamed: 0 | x | y | shape | size | color | label | class | isTrue | mainShape | size_d | Importance | img_idx |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.725741 | 0.578848 | circle | 0.041667 | yellow | true circle | 0 | True | circle | small | 0.356176 | 185 |
| 1 | 0.669418 | 0.672296 | circle | 0.075 | yellow | true circle | 0 | True | circle | big | 0.89548 | 185 |
| 2 | 0.606276 | 0.78211 | circle | 0.041667 | blue | true circle | 0 | True | circle | small | 0.650871 | 185 |
| 3 | 0.441361 | 0.821049 | circle | 0.075 | yellow | true circle | 0 | True | circle | big | 0.342942 | 185 |
| 4 | 0.306318 | 0.747182 | circle | 0.041667 | blue | true circle | 0 | True | circle | small | 0.60895 | 185 |
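
The `size_d` column looks like a discretized version of `size` (in the rows above, 0.041667 is `small` and 0.075 is `big`). The actual cut points are not given, so the sketch below uses `pd.cut` with hypothetical bin edges chosen only to be consistent with those two rows.

```python
import pandas as pd

# Hypothetical bin edges: the real thresholds used to build size_d are not stated
EDGES = [0, 0.045, 0.07, 1]
LABELS = ["small", "medium", "big"]

def discretize_size(size):
    # intervals are right-closed: (0, 0.045], (0.045, 0.07], (0.07, 1]
    return pd.cut(size, bins=EDGES, labels=LABELS)

sizes = pd.Series([0.041667, 0.05, 0.066667, 0.075, 0.083333])
print(discretize_size(sizes).tolist())
```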


### Analysis of the obtained data
### Aggregation and helper functions

In [1]:
import pickle
import numpy as np
import pandas as pd

In [117]:
CLASSES = ['Circle', 'None', 'Square', 'Triangle']

In [218]:
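def aggregate(data, by, agg_measures, columns='Importance', where=None):
    # Optionally keep only selected rows, e.g. figures with a positive importance
    x = data if where is None else data[where]
    # Statistics per image and group first, then averaged over images;
    # selecting `columns` before .agg avoids aggregating the text columns
    return (x.groupby(by=['img_idx'] + by)[columns]
             .agg(agg_measures)
             .groupby(level=by)
             .mean())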
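
`aggregate` first computes the statistics per image (`img_idx`) and group, and only then averages them over images, so every image carries the same weight regardless of how many figures it contains. A toy example with made-up numbers contrasting this with a pooled mean over all figures:

```python
import pandas as pd

toy = pd.DataFrame({
    "img_idx": [0, 0, 0, 1],
    "shape": ["circle", "circle", "circle", "circle"],
    "Importance": [0.1, 0.1, 0.1, 0.9],
})

# Two-stage: mean per image (0.1 and 0.9), then mean over images: 0.5
per_image = (toy.groupby(["img_idx", "shape"])["Importance"]
                .agg(["mean"])
                .groupby(level="shape")
                .mean())
# Pooled: all four figures together: (3 * 0.1 + 0.9) / 4 = 0.3
pooled = toy.groupby("shape")["Importance"].mean()
print(per_image.loc["circle", "mean"], pooled["circle"])
```

The image with a single strong figure pulls the two-stage mean up, while the pooled mean is dominated by the image with more figures.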

In [120]:
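def aggregate_ratio(data, by, agg_measures, columns='Importance', where=None):
    # Statistic on the filtered rows divided by the same statistic on all rows,
    # e.g. mean number of positive figures / mean number of all figures per image
    filtered = aggregate(data=data, by=by, agg_measures=agg_measures,
                         columns=columns, where=where)
    total = aggregate(data=data, by=by, agg_measures=agg_measures,
                      columns=columns, where=None)
    return filtered / total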
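
With `agg_measures=['count']` and `where=Importance > 0`, `aggregate_ratio` returns the mean number of positive figures per image divided by the mean number of all figures per image. Images with no positive figure in a group do not enter the numerator's mean, so the result is not exactly the average per-image fraction. A toy check with made-up numbers (`mean_count` is a hypothetical stand-in for `aggregate(..., agg_measures=['count'])`):

```python
import pandas as pd

def mean_count(df):
    # count figures per (image, shape), then average the counts over images
    return (df.groupby(["img_idx", "shape"])["Importance"].count()
              .groupby(level="shape").mean())

toy = pd.DataFrame({
    "img_idx": [0, 0, 1, 1, 1],
    "shape": ["circle"] * 5,
    "Importance": [0.2, -0.1, -0.3, -0.4, -0.5],
})
# image 1 has no positive figure, so the numerator averages over image 0 only
ratio = mean_count(toy[toy.Importance > 0]) / mean_count(toy)
# average of per-image fractions: (1/2 + 0/3) / 2
per_image = (toy.Importance > 0).groupby(toy.img_idx).mean().mean()
print(ratio["circle"], per_image)  # 0.4 0.25
```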


In [139]:
def create_merged_results(datasets, aggregate_func, wheres=None, **kwargs):
    # One aggregated table per class-specific dataset, stacked under a 'class' index level
    results = []
    if wheres is None:
        wheres = [None] * len(datasets)
    for data, where in zip(datasets, wheres):
        x = aggregate_func(data=data, where=where, **kwargs)
        # every file holds a single class, so its first row identifies it
        x['class'] = CLASSES[data['class'].iloc[0]]
        results.append(x)

    con = pd.concat(results)
    return con.set_index(['class', con.index]).sort_index()

##### Loading the data

In [220]:
_0 = pd.read_csv('./csv/_0_0_importance.csv')
_1 = pd.read_csv('./csv/_1_1_importane.csv')
_2 = pd.read_csv('./csv/_2_2_importane.csv')
_3 = pd.read_csv('./csv/_3_3_importane.csv')

datasets = [_0,_1,_2,_3]
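
`create_merged_results` labels each file with the class of its first row, and 100 images were selected per class, so a quick check of these assumptions after loading can help. A hypothetical helper (`check_dataset` is not part of the original notebook; `img_idx` is assumed to identify one image within a file):

```python
import pandas as pd

def check_dataset(data, expected_class, expected_images=100):
    """Hypothetical sanity check for one per-class importance file."""
    # the whole file is named after the class of its first row,
    # so every row should belong to the same class
    assert data["class"].nunique() == 1, "file mixes several classes"
    assert data["class"].iloc[0] == expected_class, "unexpected class"
    # 100 correctly classified images were selected per class
    assert data["img_idx"].nunique() == expected_images, "unexpected number of images"
    return True

# In the notebook this would be: for k, data in enumerate(datasets): check_dataset(data, k)
toy = pd.DataFrame({"class": [2, 2, 2], "img_idx": [7, 7, 9], "Importance": [0.1, -0.2, 0.3]})
print(check_dataset(toy, expected_class=2, expected_images=2))  # True
```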

### Mean influence by shape

In [144]:
create_merged_results(datasets=[_0,_1,_2,_3], aggregate_func=aggregate, by=['shape'], agg_measures=['min','max', 'mean'])

| class | shape | min | max | mean |
|---|---|---|---|---|
| Circle | circle | -0.553896 | 0.83018 | 0.16914 |
| Circle | square | -0.411557 | 0.471737 | 0.007745 |
| Circle | triangle | -0.221118 | 0.188074 | -0.01854 |
| None | circle | -0.629873 | 0.488206 | -0.096776 |
| None | square | -0.508448 | 0.737067 | 0.114 |
| None | triangle | -0.198143 | 0.321154 | 0.068725 |
| Square | circle | -0.365289 | 0.441253 | 0.018298 |
| Square | square | -0.599299 | 0.881642 | 0.155575 |
| Square | triangle | -0.141003 | 0.186853 | 0.027493 |
| Triangle | circle | -0.40116 | 0.493134 | 0.028857 |
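
The long format makes comparisons across classes awkward. A minimal sketch of turning the `mean` column into a class-by-shape matrix with `unstack`, using only the Circle and Square rows copied from the table above (the displayed output is truncated, so the other rows are left out); on the full result the same call would be `result['mean'].unstack('shape')`, where `result` is a hypothetical name for the returned frame.

```python
import pandas as pd

# Mean importance copied from the Circle and Square rows of the table above
means = pd.Series(
    [0.16914, 0.007745, -0.01854, 0.018298, 0.155575, 0.027493],
    index=pd.MultiIndex.from_product(
        [["Circle", "Square"], ["circle", "square", "triangle"]],
        names=["class", "shape"],
    ),
    name="mean",
)
# one row per class, one column per small-figure shape
matrix = means.unstack("shape")
# in both of these rows the shape matching the class has the largest mean
print(matrix)
```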


### Mean influence by size and shape

In [128]:
create_merged_results(datasets=[_0,_1,_2,_3], aggregate_func=aggregate, by=['size_d'], agg_measures=['min','max', 'mean'])

| class | size_d | min | max | mean |
|---|---|---|---|---|
| Circle | big | -0.391055 | 0.75566 | 0.194732 |
| Circle | medium | -0.446473 | 0.659777 | 0.117306 |
| Circle | small | -0.423116 | 0.685096 | 0.123463 |
| None | big | -0.575319 | 0.636644 | 0.023284 |
| None | medium | -0.56024 | 0.554009 | 0.007042 |
| None | small | -0.530961 | 0.569892 | 0.022424 |
| Square | big | -0.45584 | 0.765246 | 0.155293 |
| Square | medium | -0.455885 | 0.689616 | 0.131739 |
| Square | small | -0.425272 | 0.703494 | 0.133727 |
| Triangle | big | -0.405528 | 0.536682 | 0.063085 |


In [151]:
create_merged_results(datasets=datasets, aggregate_func=aggregate, by=['shape', 'size_d'], agg_measures=['min','max', 'mean'])

| class | shape | size_d | min | max | mean |
|---|---|---|---|---|---|
| Circle | circle | big | -0.273843 | 0.692455 | 0.231685 |
| Circle | circle | medium | -0.33914 | 0.614302 | 0.144424 |
| Circle | circle | small | -0.354063 | 0.620969 | 0.145262 |
| Circle | square | big | -0.257279 | 0.324376 | 0.030317 |
| Circle | square | medium | -0.265646 | 0.123521 | -0.073533 |
| Circle | square | small | -0.112815 | 0.250156 | 0.062198 |
| Circle | triangle | big | -0.113889 | 0.107939 | -0.003676 |
| Circle | triangle | medium | -0.071156 | 0.132011 | 0.027457 |
| Circle | triangle | small | -0.155361 | 0.09432 | -0.023595 |
| None | circle | big | -0.405402 | 0.213369 | -0.098006 |


### What fraction of figures has a positive influence on the prediction, by size

In [171]:
create_merged_results(datasets=datasets, aggregate_func=aggregate_ratio,
                      by=['size_d'], agg_measures=['count'],
                      wheres=[data.Importance > 0 for data in datasets])

| class | size_d | count |
|---|---|---|
| Circle | big | 0.593646 |
| Circle | medium | 0.534188 |
| Circle | small | 0.564732 |
| None | big | 0.519052 |
| None | medium | 0.522624 |
| None | small | 0.541614 |
| Square | big | 0.591682 |
| Square | medium | 0.575576 |
| Square | small | 0.569511 |
| Triangle | big | 0.588235 |


### The same, additionally split by shape

In [172]:
create_merged_results(datasets=datasets, aggregate_func=aggregate_ratio,
                      by=['size_d', 'shape'], agg_measures=['count'],
                      wheres=[data.Importance > 0 for data in datasets])

| class | size_d | shape | count |
|---|---|---|---|
| Circle | big | circle | 0.643559 |
| Circle | big | square | 0.669261 |
| Circle | big | triangle | 0.662722 |
| Circle | medium | circle | 0.586562 |
| Circle | medium | square | 0.70891 |
| Circle | medium | triangle | 0.736897 |
| Circle | small | circle | 0.612014 |
| Circle | small | square | 0.726496 |
| Circle | small | triangle | 0.757551 |
| None | big | circle | 0.580645 |


### What share of the figures in an image has a positive influence, and how strong is it?

In [198]:
create_merged_results(datasets=datasets, aggregate_func=aggregate_ratio,
                      by=['isTrue'], agg_measures=['count'],
                      wheres=[data.Importance > 0 for data in datasets])

| class | isTrue | count |
|---|---|---|
| Circle | False | 0.505162 |
| Circle | True | 0.615385 |
| None | False | 0.517002 |
| Square | False | 0.532975 |
| Square | True | 0.612472 |
| Triangle | False | 0.538048 |
| Triangle | True | 0.600931 |


In [199]:
create_merged_results(datasets=datasets, aggregate_func=aggregate,
                      by=['isTrue'], agg_measures=['min', 'max', 'mean'],
                      wheres=[data.Importance > 0 for data in datasets])

| class | isTrue | min | max | mean |
|---|---|---|---|---|
| Circle | False | 0.054737 | 0.741305 | 0.346495 |
| Circle | True | 0.093478 | 0.771764 | 0.416223 |
| None | False | 0.041684 | 0.799374 | 0.344201 |
| Square | False | 0.057037 | 0.683448 | 0.320224 |
| Square | True | 0.070007 | 0.869354 | 0.414985 |
| Triangle | False | 0.097046 | 0.73127 | 0.375173 |
| Triangle | True | 0.083282 | 0.492244 | 0.273749 |


### What share of the figures in an image has a negative influence, and how strong is it?

In [201]:
create_merged_results(datasets=datasets, aggregate_func=aggregate_ratio,
                      by=['isTrue'], agg_measures=['count'],
                      wheres=[data.Importance < 0 for data in datasets])

| class | isTrue | count |
|---|---|---|
| Circle | False | 0.494838 |
| Circle | True | 0.42735 |
| None | False | 0.482998 |
| Square | False | 0.473422 |
| Square | True | 0.412264 |
| Triangle | False | 0.461952 |
| Triangle | True | 0.580965 |


In [203]:
create_merged_results(datasets=datasets, aggregate_func=aggregate,
                      by=['isTrue'], agg_measures=['min', 'max', 'mean'],
                      wheres=[data.Importance < 0 for data in datasets])

| class | isTrue | min | max | mean |
|---|---|---|---|---|
| Circle | False | -0.644269 | -0.065046 | -0.30919 |
| Circle | True | -0.572265 | -0.076269 | -0.28688 |
| None | False | -0.766412 | -0.064808 | -0.344842 |
| Square | False | -0.599331 | -0.048182 | -0.276598 |
| Square | True | -0.586475 | -0.078054 | -0.278764 |
| Triangle | False | -0.667575 | -0.085036 | -0.331044 |
| Triangle | True | -0.424785 | -0.059857 | -0.220793 |


### Mean influence of figure color on the prediction

In [206]:
create_merged_results(datasets=datasets, aggregate_func=aggregate,
                      by=['color'], agg_measures=['mean'])

| class | color | mean |
|---|---|---|
| Circle | blue | 0.173565 |
| Circle | red | 0.149544 |
| Circle | yellow | 0.118836 |
| None | blue | 0.035612 |
| None | red | -0.004836 |
| None | yellow | 0.021891 |
| Square | blue | 0.100305 |
| Square | red | 0.148393 |
| Square | yellow | 0.173712 |
| Triangle | blue | 0.075439 |


### Mean influence of figure color and shape on the prediction

In [207]:
create_merged_results(datasets=datasets, aggregate_func=aggregate,
                      by=['color', 'shape'], agg_measures=['mean'])

| class | color | shape | mean |
|---|---|---|---|
| Circle | blue | circle | 0.198428 |
| Circle | blue | square | 0.037948 |
| Circle | blue | triangle | 0.012097 |
| Circle | red | circle | 0.181888 |
| Circle | red | square | -0.019762 |
| Circle | red | triangle | -0.017217 |
| Circle | yellow | circle | 0.132266 |
| Circle | yellow | square | 0.009566 |
| Circle | yellow | triangle | -0.070852 |
| None | blue | circle | -0.062155 |
