## Import libraries

In [1]:
import sys
# Make the directory two levels up importable so the local `mlinterpreter`
# package (imported in the next cell) resolves — presumably the repository
# root; adjust this path if the notebook is moved.
sys.path.append('../../')

In [2]:
import pandas as pd
import numpy as np
import os
import math
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.datasets import load_svmlight_file
import torch
from mlinterpreter.mmd.mmd_critic import Dataset, select_prototypes, select_criticisms
from tqdm import tqdm, tqdm_notebook
from sklearn.preprocessing import StandardScaler
from sklearn import preprocessing

In [3]:
def load_data(path, target):
    """Read a CSV and turn it into feature/label tensors for MMD-critic.

    Every column except ``target`` is a feature. Numeric feature columns are
    standardized column by column with ``StandardScaler``; non-numeric feature
    columns are label-encoded to integer codes (their cardinality is printed).

    Args:
        path: File path (or anything ``pd.read_csv`` accepts) of the CSV.
        target: Name of the label column.

    Returns:
        X: ``torch.float`` tensor of shape (n_rows, n_features), features in
            the CSV's column order.
        y: ``torch.long`` tensor of shape (n_rows,). A non-numeric target is
            label-encoded first so it can be converted to integers.
        df_origin: Unmodified copy of the raw DataFrame, useful for showing
            selected rows with their original values.

    Raises:
        KeyError: If ``target`` is not a column of the CSV.
    """
    df = pd.read_csv(path)
    df_origin = df.copy()

    if target not in df.columns:
        raise KeyError(f'target column {target!r} not found in {path}')

    feature_cols = [col for col in df.columns if col != target]
    # Every non-numeric dtype (object, string, category) is treated as
    # categorical, so such columns are encoded rather than handed to
    # StandardScaler (which cannot scale strings).
    num_cols = [col for col in feature_cols if pd.api.types.is_numeric_dtype(df[col])]
    cat_cols = [col for col in feature_cols if col not in num_cols]

    for col in num_cols:
        # One scaler per column; fit_transform == fit followed by transform.
        df[col] = StandardScaler().fit_transform(df[[col]].to_numpy()).ravel()

    for col in cat_cols:
        print(col, df[col].nunique())
        # Fit and transform on the same string view. Fitting on raw values
        # while transforming `astype(str)` values breaks on NaN / mixed-type
        # columns (unorderable classes at fit, unseen 'nan' at transform).
        df[col] = preprocessing.LabelEncoder().fit_transform(df[col].astype(str))

    labels = df[target]
    if pd.api.types.is_numeric_dtype(labels):
        label_values = labels.to_numpy()
    else:
        # e.g. '<=50K' / '>50K' strings cannot be cast to torch.long directly.
        label_values = preprocessing.LabelEncoder().fit_transform(labels.astype(str))

    X = torch.tensor(df[feature_cols].to_numpy(), dtype=torch.float)
    y = torch.tensor(np.asarray(label_values), dtype=torch.long)

    return X, y, df_origin

## Parameter settings

In [4]:
# Bandwidth handed to Dataset.compute_rbf_kernel / compute_local_rbf_kernel.
gamma = 0.026

# Number of rows MMD-critic should pick; a value of 0 skips that step
# (criticisms are only computed when prototypes are).
num_prototypes = 32
num_criticisms = 10

# Kernel variant: 'local' or 'global'.
kernel_type = 'local'

# Criticism regularizer passed to select_criticisms: None, 'logdet' or 'iterative'.
regularizer = 'logdet'

# Input CSV and the name of its label column.
path = '../../demo_data/adult.csv'
target = 'income'

## Model: select prototypes and criticisms with MMD-critic

In [5]:
X_train, y_train, df_origin = load_data(path, target)


def _sorted_by_label(dataset, indices):
    """Gather rows and labels of ``dataset`` at ``indices``, ordered by label.

    Both outputs are gathered with the same permutation, so ``rows[i]``
    always belongs to ``labels[i]``.
    """
    ordered = indices[dataset.y[indices].argsort()]
    return dataset.X[ordered], dataset.y[ordered]


# Fail fast on a bad config value (ValueError: it is an invalid setting,
# not a failed mapping lookup).
if kernel_type not in ('global', 'local'):
    raise ValueError(f'kernel_type must be either "global" or "local", got {kernel_type!r}')

# NOTE(review): if `d_train.K` is a dense n_rows x n_rows matrix, the full
# adult data needs several GB of memory here — confirm in mlinterpreter.
d_train = Dataset(X_train, y_train)
print('Computing kernel...', end='', flush=True)
if kernel_type == 'global':
    d_train.compute_rbf_kernel(gamma)
else:
    d_train.compute_local_rbf_kernel(gamma)
print('Done.', flush=True)

# Prototypes are selected from the kernel matrix; criticisms are selected
# relative to them, so they are only computed when prototypes are.
if num_prototypes > 0:
    print('Computing prototypes...', end='', flush=True)
    prototype_indices = select_prototypes(d_train.K, num_prototypes)

    # `prototypes` / `prototype_labels` follow `prototype_indices`; the
    # *_sorted pair holds the same rows ordered by label. Keeping them apart
    # prevents pairing unsorted rows with sorted labels.
    prototypes = d_train.X[prototype_indices]
    prototype_labels = d_train.y[prototype_indices]
    prototypes_sorted, prototype_labels_sorted = _sorted_by_label(d_train, prototype_indices)
    print('Done.', flush=True)
    print(sorted(prototype_indices.tolist()))

    # Criticisms
    if num_criticisms > 0:
        print('Computing criticisms...', end='', flush=True)
        criticism_indices = select_criticisms(d_train.K, prototype_indices, num_criticisms, regularizer)

        criticisms = d_train.X[criticism_indices]
        criticism_labels = d_train.y[criticism_indices]
        criticisms_sorted, criticism_labels_sorted = _sorted_by_label(d_train, criticism_indices)
        print('Done.', flush=True)
        print(sorted(criticism_indices.tolist()))

workclass 9
education 16
marital.status 7
occupation 15
relationship 6
race 5
sex 2
native.country 42
Done.
Computing prototypes...Done.
[193, 2147, 3862, 4426, 4908, 5611, 6808, 6899, 7814, 8247, 8320, 9393, 10440, 12257, 12778, 16270, 16414, 16520, 17719, 20478, 20970, 21650, 22147, 22856, 24101, 25749, 26317, 27196, 27840, 28968, 30736, 32004]
Computing criticisms...Done.
[2094, 6446, 8370, 10324, 15777, 23747, 24873, 26398, 26434, 27814]


## Prototypes

In [6]:
# Raw rows picked as prototypes, in the order select_prototypes returned them.
# df_origin keeps read_csv's default RangeIndex, so positions double as labels.
df_origin.loc[prototype_indices.tolist()]

Unnamed: 0,age,workclass,fnlwgt,education,education.num,marital.status,occupation,relationship,race,sex,capital.gain,capital.loss,hours.per.week,native.country,income
12257,23,Private,173851,Some-college,10,Never-married,Craft-repair,Not-in-family,White,Male,0,0,40,United-States,0
30736,33,Self-emp-not-inc,99761,Bachelors,13,Never-married,Other-service,Not-in-family,White,Female,0,0,15,United-States,0
5611,30,Private,27207,11th,7,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,45,United-States,0
16270,65,?,105017,Bachelors,13,Married-civ-spouse,?,Husband,White,Male,0,0,40,United-States,0
8247,46,Private,34377,Assoc-acdm,12,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,40,United-States,1
2147,38,Private,217349,Assoc-voc,11,Divorced,Prof-specialty,Not-in-family,White,Female,14344,0,40,United-States,1
6899,46,Local-gov,102076,HS-grad,9,Divorced,Adm-clerical,Unmarried,White,Female,0,0,25,United-States,0
32004,25,Private,74977,Some-college,10,Married-civ-spouse,Adm-clerical,Wife,White,Female,0,0,40,United-States,0
22856,29,Self-emp-inc,266070,Bachelors,13,Divorced,Prof-specialty,Not-in-family,White,Female,0,0,80,United-States,0
28968,41,Self-emp-not-inc,117012,Masters,14,Never-married,Exec-managerial,Not-in-family,White,Female,0,0,55,United-States,0


## Criticisms

In [7]:
# Raw rows picked as criticisms, in the order select_criticisms returned them.
# df_origin keeps read_csv's default RangeIndex, so positions double as labels.
df_origin.loc[criticism_indices.tolist()]

Unnamed: 0,age,workclass,fnlwgt,education,education.num,marital.status,occupation,relationship,race,sex,capital.gain,capital.loss,hours.per.week,native.country,income
10324,28,Private,259840,Bachelors,13,Married-civ-spouse,Sales,Husband,White,Male,0,0,60,United-States,1
2094,46,Self-emp-inc,192779,Prof-school,15,Married-civ-spouse,Prof-specialty,Husband,White,Male,15024,0,60,United-States,1
8370,24,Private,184400,HS-grad,9,Never-married,Transport-moving,Own-child,Asian-Pac-Islander,Male,0,0,30,?,0
6446,29,Self-emp-not-inc,29616,HS-grad,9,Never-married,Farming-fishing,Own-child,White,Male,0,0,65,United-States,0
27814,35,Federal-gov,182898,HS-grad,9,Married-civ-spouse,Adm-clerical,Husband,White,Male,0,0,40,United-States,0
23747,46,Private,165953,Bachelors,13,Married-civ-spouse,Adm-clerical,Husband,White,Male,0,0,45,United-States,1
26398,19,Self-emp-not-inc,137578,HS-grad,9,Never-married,Other-service,Own-child,White,Male,0,0,53,United-States,0
15777,23,Private,388811,HS-grad,9,Never-married,Adm-clerical,Own-child,White,Female,0,0,40,United-States,0
26434,30,Private,251411,Some-college,10,Divorced,Adm-clerical,Unmarried,White,Female,0,0,40,United-States,0
24873,41,Private,213019,Assoc-voc,11,Married-civ-spouse,Tech-support,Wife,White,Female,0,0,38,United-States,1


Reference: [MMD-critic](https://github.com/maxidl/MMD-critic)