In [2]:
# Imports

import itertools
from itertools import permutations
from math import tanh
import os
import pickle
import platform
import random
from tkinter import Tk

from cvxopt import solvers, matrix
import math
from matplotlib import animation
from  matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter
import numpy as np
import torch
import torchvision
from torchvision import transforms, models,datasets
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Extend width of Jupyter Notebook Cell to the size of browser
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# OS related settings
if platform.system() == 'Windows':
    print('Windows')
#     %matplotlib tk
    %matplotlib qt
elif platform.system() == 'Darwin':
    print('macOS')
    Tk().withdraw()
    %matplotlib osx
elif platform == 'linux' or platform == 'linux2':
    print('Linux')
# This line of "print" must exist right after %matplotlib command, otherwise JN will hang on the first import statement after this.
print('Interactive plot activated')

macOS
Interactive plot activated


In [3]:
# Functions


def cal_chi(fm, x):
    """
    Computes the Choquet integral (ChI) of input x with respect to fuzzy measure fm.

    Inputs are visited from largest to smallest value. Each one is weighted by the
    increase in measure obtained when its (1-based) index joins the set of indices
    already visited.

    :param fm: Fuzzy measure; a dict keyed by str() of sorted 1-based numpy index
               arrays (e.g. '[1 3]'), the key format built by get_keys_index
    :param x: Input values (presumably a 1-D numpy array; `-x` requires an ndarray)
    :return: Single value ChI output
    """
    order = np.argsort(-x) + 1  # 1-based indices of x, largest value first
    ch = x[order[0] - 1] * fm[str(order[:1])]
    for k in range(1, len(x)):
        current = fm[str(np.sort(order[:k + 1]))]
        previous = fm[str(np.sort(order[:k]))]
        ch = ch + x[order[k] - 1] * (current - previous)
    return ch


def get_cal_chi(fm):
    """
    Binds a fuzzy measure to cal_chi.

    :param fm: Fuzzy measure dictionary
    :return: A one-argument callable x -> cal_chi(fm, x)
    """
    def chi(x):
        return cal_chi(fm, x)
    return chi



def get_keys_index(dim):
    """
    Sets up a dictionary for referencing FM.

    One entry is created per non-empty subset of {1, ..., dim}. Entries are ordered by
    subset size and then lexicographically. Each key is str() of the subset as a numpy
    array (e.g. '[1 3]'), and every value is initialised to 1.

    :param dim: Number of inputs
    :return: The keys to the dictionary (a dict whose values are all 1)
    """
    values = np.arange(1, dim + 1)
    # Keep the np.array conversion: the key text must match numpy's array str() format
    return {
        str(subset): 1
        for size in range(1, dim + 1)
        for subset in np.array(list(itertools.combinations(values, size)))
    }


def get_min_fm_target(dim):
    """
    Builds the fuzzy measure whose Choquet integral (see cal_chi) is the minimum of the inputs.

    The full set {1, ..., dim} gets measure 1 and every proper subset gets 0.

    :param dim: Number of inputs
    :return: FM dictionary with the keys produced by get_keys_index(dim)
    """
    fm = get_keys_index(dim)
    for key in fm.keys():
        # Count subset members from the bracket-stripped key. Splitting the raw key
        # miscounts for dim >= 10, where numpy pads entries (e.g. '[ 1 10]').
        subset_size = len(key[1:-1].split())
        fm[key] = 1 if subset_size == dim else 0
    return fm
    
    
def get_max_fm_target(dim):
    """
    Builds the fuzzy measure whose Choquet integral (see cal_chi) is the maximum of the inputs.

    Every non-empty subset gets measure 1, which is exactly what get_keys_index initialises.

    :param dim: Number of inputs
    :return: FM dictionary with every value equal to 1
    """
    return get_keys_index(dim)


def get_mean_fm_target(dim):
    """
    Builds the fuzzy measure whose Choquet integral (see cal_chi) is the arithmetic mean.

    Each subset A gets measure |A| / dim.

    :param dim: Number of inputs
    :return: FM dictionary with the keys produced by get_keys_index(dim)
    """
    fm = get_keys_index(dim)
    for key in fm.keys():
        # Count subset members from the bracket-stripped key. Splitting the raw key
        # miscounts for dim >= 10, where numpy pads entries (e.g. '[ 1 10]').
        fm[key] = len(key[1:-1].split()) / dim
    return fm


def get_gmean_fm_target(dim):
    """
    Returns the FM target used for the 'gmean' aggregation.

    NOTE(review): this is currently identical to get_mean_fm_target(dim), i.e. |A| / dim
    for every subset A. If a geometric-mean-specific target was intended, it is not
    implemented here; confirm.

    :param dim: Number of inputs
    :return: FM dictionary
    """
    return get_mean_fm_target(dim)



def create_synthetic_data(num_samples=100, accuracies=(0.9, 0.6, 0.5), seed=None):
    """
    Generates random binary labels plus one synthetic soft output per simulated classifier.

    For classifier k, round((1 - accuracies[k]) * num_samples) distinct samples get their
    label flipped. The binary prediction is then turned into a score in [0.5, 1) for a
    predicted 1 and in [0, 0.5) for a predicted 0.

    :param num_samples: Number of samples
    :param accuracies: Target accuracy of each simulated classifier (values in [0, 1])
    :param seed: Optional seed. When None, the global numpy and `random` RNGs are used,
                 exactly as before. When given, private generators are used, so results are
                 reproducible and global RNG state is left untouched.
    :return: (label, outputs) where label has shape (num_samples,) and outputs has
             shape (len(accuracies), num_samples)
    """
    np_rng = np.random if seed is None else np.random.RandomState(seed)
    py_rng = random if seed is None else random.Random(seed)

    label = np_rng.randint(0, 2, num_samples)

    # Indices of the samples each classifier gets wrong
    flip_ind = [
        np_rng.choice(range(num_samples), round((1 - acc) * num_samples), replace=False)
        for acc in accuracies
    ]

    outputs = []
    for ind in flip_ind:
        output_bin = np.copy(label)
        output_bin[ind] = 1 - output_bin[ind]
        output = np.asarray([(py_rng.random() + 1) / 2 if o_b == 1 else py_rng.random() / 2
                             for o_b in output_bin])
        outputs.append(output)

    return label, np.asarray(outputs)
    

def test_accuracy(target, output):
    acc = np.sum(target == output.round()) / len(target)
    return acc

In [7]:
# Reproducibility: create_synthetic_data draws from both numpy's and Python's global RNGs
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

label, outputs = create_synthetic_data()

dim = len(outputs)  # one row of `outputs` per simulated classifier

# pA[key]: accuracy of coalition `key`, scored on the average of its members' outputs
pA = get_keys_index(dim)

for key in pA.keys():
    # Keys look like '[1 3]': strip the brackets, parse the 1-based indices, make them 0-based rows
    key_int = np.asarray(key[1:-1].split()).astype(int) - 1
    output_coalition = np.mean(outputs[key_int, :], 0)
    acc_coalition = test_accuracy(label, output_coalition)
    pA[key] = acc_coalition

print(pA.values())

# Min-max rescale the coalition accuracies into [0.5, 1]
pA_values = np.asarray(list(pA.values()), dtype=float)
a = np.amax(pA_values) - np.amin(pA_values)
if a > 0:
    pA_values = ((pA_values - np.amin(pA_values)) / a + 1) / 2
else:
    # All coalitions are equally accurate: avoid 0/0 (NaN) and give every coalition 1.0
    pA_values = np.ones_like(pA_values)
print(pA_values)

for key, value in zip(pA.keys(), pA_values):
    pA[key] = value

print(pA.values())

dict_values([0.9, 0.6, 0.5, 0.69, 0.65, 0.52, 0.65])
[1.     0.625  0.5    0.7375 0.6875 0.525  0.6875]
dict_values([1.0, 0.625, 0.5, 0.7374999999999999, 0.6875, 0.525, 0.6875])
