In [1]:
# Imports

import itertools
from itertools import permutations
import os
import pickle
import platform
import random
from tkinter import Tk

from cvxopt import solvers, matrix
import math
from matplotlib import animation
from  matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter
import numpy as np
import torch
import torchvision
from torchvision import transforms, models,datasets
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Extend width of Jupyter Notebook Cell to the size of browser
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# OS related settings
if platform.system() == 'Windows':
    print('Windows')
#     %matplotlib tk
    %matplotlib qt
elif platform.system() == 'Darwin':
    print('macOS')
    Tk().withdraw()
    %matplotlib osx
elif platform == 'linux' or platform == 'linux2':
    print('Linux')
# This line of "print" must exist right after %matplotlib command, otherwise JN will hang on the first import statement after this.
print('Interactive plot activated')

macOS
Interactive plot activated


In [2]:
# ChINN

# Convert decimal to binary string
def sources_and_subsets_nodes(N):
    str1 = "{0:{fill}"+str(N)+"b}"
    a = []
    for i in range(1,2**N):
        a.append(str1.format(i, fill='0'))

    sourcesInNode = []
    sourcesNotInNode = []
    subset = []
    sourceList = list(range(N))
    # find subset nodes of a node
    def node_subset(node, sourcesInNodes):
        return [node - 2**(i) for i in sourcesInNodes]
    
    # convert binary encoded string to integer list
    def string_to_integer_array(s, ch):
        N = len(s) 
        return [(N - i - 1) for i, ltr in enumerate(s) if ltr == ch]
    
    for j in range(len(a)):
        # index from right to left
        idxLR = string_to_integer_array(a[j],'1')
        sourcesInNode.append(idxLR)  
        sourcesNotInNode.append(list(set(sourceList) - set(idxLR)))
        subset.append(node_subset(j,idxLR))

    return sourcesInNode, subset


def subset_to_indices(indices):
    return [i for i in indices]

class Choquet_Integral_NN(torch.nn.Module):
    
    def __init__(self, N_in, N_out):
        super(Choquet_Integral_NN,self).__init__()
        self.N_in = N_in
        self.N_out = N_out
        self.nVars = 2**self.N_in - 2
        
        # The FM is initialized with mean
        dummy = (1./self.N_in) * torch.ones((self.nVars, self.N_out), requires_grad=True)
#        self.vars = torch.nn.Parameter( torch.Tensor(self.nVars,N_out))
        self.vars = torch.nn.Parameter(dummy)
        
        # following function uses numpy vs pytorch
        self.sourcesInNode, self.subset = sources_and_subsets_nodes(self.N_in)
        
        self.sourcesInNode = [torch.tensor(x) for x in self.sourcesInNode]
        self.subset = [torch.tensor(x) for x in self.subset]
        
    def forward(self,inputs):    
        self.FM = self.chi_nn_vars(self.vars)
        sortInputs, sortInd = torch.sort(inputs,1, True)
        M, N = inputs.size()
        sortInputs = torch.cat((sortInputs, torch.zeros(M,1)), 1)
        sortInputs = sortInputs[:,:-1] -  sortInputs[:,1:]
        
        out = torch.cumsum(torch.pow(2,sortInd),1) - torch.ones(1, dtype=torch.int64)
        
        data = torch.zeros((M,self.nVars+1))
        
        for i in range(M):
            data[i,out[i,:]] = sortInputs[i,:] 
        
        
        ChI = torch.matmul(data,self.FM)
            
        return ChI
    
    # Converts NN-vars to FM vars
    def chi_nn_vars(self, chi_vars):
#        nVars,_ = chi_vars.size()
        chi_vars = torch.abs(chi_vars)
        #        nInputs = inputs.get_shape().as_list()[1]
        
        FM = chi_vars[None, 0,:]
        for i in range(1,self.nVars):
            indices = subset_to_indices(self.subset[i])
            if (len(indices) == 1):
                FM = torch.cat((FM,chi_vars[None,i,:]),0)
            else:
                #         ss=tf.gather_nd(variables, [[1],[2]])
                maxVal,_ = torch.max(FM[indices,:],0)
                temp = torch.add(maxVal,chi_vars[i,:])
                FM = torch.cat((FM,temp[None,:]),0)
              
        FM = torch.cat([FM, torch.ones((1,self.N_out))],0)
        FM = torch.min(FM, torch.ones(1))  
        
        return FM
    



In [3]:
# Functions

def create_dataset(dim, all_perms, num_per_perm, superset_factor):
    """
    Create a dataset with all possible permutation, with each permutation having the same number of samples.
    
    :param dim: Dimension of data sample
    :param all_perms: A list of permutation. 
                      Use as an input so that the algorithm creates data for different permutations in the order assigned.
    :param num_per_perm: Number of data samples for each permutation
    :param superset_factor: Data for each permutation are pulled from a randomly generated dataset. 
                            To ensure that each permutation gets at least #num_per_perm# data samples,
                            create a dataset #superset_factor* times (normally 3 will be enough) bigger than the dataset wanted.
    """
    # Every permutation gets the same number of train/test data samples,
    # To ensure that, calculate the number of data needed in total, 
    # and generate a super dataset that is multiple times bigger.
    num = math.factorial(dim) * num_per_perm
    # Create superset
    data_superset = np.random.rand(dim, num*superset_factor)
    # Get permutation of each data sample
    data_perms = np.argsort(data_superset, 0)
    # N! possible permutations
#     all_perms = list(permutations(list(range(dim))))
    # Group data sample according to its permutation
    data_idx_superset_by_perm = []
    
    for i, current_perm in enumerate(all_perms):
        # Get index of data sample of certain permutation and save to list
        temp = np.where(data_perms[0, :]==current_perm[0])
        for idx, p in enumerate(current_perm):
            temp = np.intersect1d(temp, np.where(data_perms[idx, :]==p))
        if temp.size < num_per_perm:
            print('Current permutation doesn\'t have sufficient number of samples. Please regenerate!')
            exit()
        data_idx_superset_by_perm.append(temp)
    
    # Every permutation gets the same number of train/test data samples,
    # Data is randomly pull from superset each epoch
    data_idx_by_perm = []
    for i in range(len(all_perms)):
        temp = data_idx_superset_by_perm[i]
        random.shuffle(temp)
        data_idx_by_perm.append(temp[0:num_per_perm])
        
    return data_superset, data_idx_superset_by_perm, data_idx_by_perm


def cal_chi(fm, x):
    """
    Calculates ChI with given fuzzy measure and input
    
    :param fm: Fuzzy measure
    :param x: Input
    :return: Single value Chi output
    """
    pi_i = np.argsort(-x) + 1 # Arg sort of input, with the smallest index9 being 1
    ch = x[pi_i[0] - 1] * (fm[str(pi_i[:1])])
    for i in range(1, len(x)):
        latt_pti = np.sort(pi_i[:i+1])
        latt_ptimin1 = np.sort(pi_i[:i])
        ch = ch + x[pi_i[i] - 1] * (fm[str(latt_pti)] - fm[str(latt_ptimin1)])
    return ch


def get_cal_chi(fm):
    return lambda x: cal_chi(fm, x)


def train_chinn(chinn, lr, criterion, optimizer, num_epoch, train_d, train_label):
    """
    Train nn for fuzzy measure
    """
    for epoch in range(num_epoch):
        # Forward pass: Compute predicted y by passing x to the model
        y_pred = chinn(train_d.transpose(0, 1))
        # Compute the loss
        loss = criterion(y_pred, train_label)
        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        

def gmean(dim):
    return lambda x, d: torch.pow(torch.prod(x, d), 1/dim)


def get_keys_index(dim):
    """
    Sets up a dictionary for referencing FM.
    :return: The keys to the dictionary
    """
    vls = np.arange(1, dim + 1)
    Lattice = {}
    for i in range(1, dim + 1):
        A = np.array(list(itertools.combinations(vls, i)))
        for latt_pt in A:
            Lattice[str(latt_pt)] = 1
    return Lattice


def get_min_fm_target(dim):
    fm = get_keys_index(dim)
    for key in fm.keys():
        if len(key.split()) != dim:
            fm[key] = 0
        else:
            fm[key] = 1
    return fm
    
    
def get_max_fm_target(dim):
    fm = get_keys_index(dim)
    return fm


def get_mean_fm_target(dim):
    fm = get_keys_index(dim)
    for key in fm.keys():
        fm[key] = len(key.split()) / dim
    return fm


def get_gmean_fm_target(dim):
    fm = get_mean_fm_target(dim)
    return fm

In [4]:
# Run

# Parameters
num_repetition = 100
num_epoch = 100
dim_list = list(range(3, 4))
num_per_perm_train = 10 # Every permutation gets the same number of train/test data samples
num_per_perm_test = 10 # Every permutation gets the same number of train/test data samples
superset_factor = 4 # Create a superset # times larger than what is needed to ensure each class would have sufficient number of samples
eva_funcs = [torch.amin, torch.amax, torch.mean, gmean(2)] # List of evaluation functions, initialize with dim=2

MSEs_seen_by_dim = []
MSEs_unseen_by_dim = []

FM_by_dim = []

output_dir = 'output/'

# Train&Test for dim = 3 to 8
for dim in dim_list:
    eva_funcs[-1] = gmean(dim) # List of evaluation functions
    
    all_perms = list(permutations(list(range(dim)))) # N! possible permutations

    # When the # of possible permutations exceed certain number (in here 5!), 
    #instead of feeding only one more permutation a time, feed more.
    train_group_num_limit = math.factorial(5)
    if len(all_perms) > train_group_num_limit:
        step = int(len(all_perms) / train_group_num_limit)
    else:
        step = 1
    
    # Mean Squared Error for each evaluation function, for each percentage, for each repetition, of all test samples, for both seen and unseen data.
    MSEs_seen = np.zeros((len(eva_funcs), len(range(step-1, len(all_perms), step)), num_repetition))
    MSEs_unseen = np.zeros((len(eva_funcs), len(range(step-1, len(all_perms), step))-1, num_repetition))
    # Record FM after train session with both seen and unseen data pattern
    FM = np.zeros((len(eva_funcs), len(range(step-1, len(all_perms), step)), num_repetition, 2**dim-1))
    
    for rep in range(num_repetition):
        print('Repetition:', rep+1)
        random.shuffle(all_perms)
        train_data_superset, _, train_idx_by_perm = create_dataset(dim, all_perms, num_per_perm_train`, superset_factor)
        test_data_superset, _, test_idx_by_perm = create_dataset(dim, all_perms, num_per_perm_test, superset_factor)
        
        for perc_idx, perc in enumerate(tqdm(range(step-1, len(all_perms), step))):
            # Find index of train/test sample in superset and shuffle
            train_idx = np.concatenate(train_idx_by_perm[0:perc+1])
            np.random.shuffle(train_idx)
            test_idx = np.concatenate(test_idx_by_perm[0:perc+1])
            # Find data sample through index and convert from numpy array to torch tensor
            train_d = train_data_superset[:, train_idx]
            test_d = test_data_superset[:, test_idx]
            train_d = torch.tensor(train_d, dtype=torch.float)
            test_d = torch.tensor(test_d, dtype=torch.float)
            # Define unseen test data samples when the train data doesn't cover 100% of the permutation
            if perc < len(all_perms)-1:
                test_idx_unseen = np.concatenate(test_idx_by_perm[perc+1:])
                test_d_unseen = test_data_superset[:, test_idx_unseen]
                test_d_unseen = torch.tensor(test_d_unseen, dtype=torch.float)
            else:
                test_d_unseen = []
            
            # Define subsets of 'X', or keys for fuzzy measure. Like '1 2' or '1 3 4 5' for g(x1, x2) or g(x1, x3, x4, x5)
            sourcesInNode, subset = sources_and_subsets_nodes(dim)
            keys = [str(np.sort(i)+1) for i in sourcesInNode]
            
            for eva_idx, eva_func in enumerate(eva_funcs):
                # Calculate label with given evaluation function
                train_label = eva_func(train_d, 0)
                test_label = eva_func(test_d, 0)
                train_label.unsqueeze_(1)
                
                
                # Initialize ChINN
                chinn = Choquet_Integral_NN(dim, 1)
                # Parameters for training NN
                lr = 0.05 # Learning rate
                criterion = torch.nn.MSELoss(reduction='mean')
                optimizer = torch.optim.SGD(chinn.parameters(), lr=lr)
                # Train 
                train_chinn(chinn, lr, criterion, optimizer, num_epoch, train_d, train_label)
                # Get fuzzy measure learned
                FM_learned = (chinn.chi_nn_vars(chinn.vars).cpu()).detach().numpy()
                fm_dict_binary = dict(zip(keys, FM_learned[:,0]))
                fm_dict_lexicographic = get_keys_index(dim)
                for key in fm_dict_lexicographic.keys():
                    fm_dict_lexicographic[key] = fm_dict_binary[key]
                FM[eva_idx, perc_idx, rep, :] = np.asarray(list(fm_dict_lexicographic.values()))
                # Calculate result from integral with test data
                test_output = np.apply_along_axis(get_cal_chi(fm_dict_lexicographic), 0, test_d)
                MSE = ((test_output - test_label.numpy())**2).mean()
                MSEs_seen[eva_idx, perc_idx, rep] = MSE
                # Calculate result from integral with test data - unseen
                if perc < len(all_perms)-1:
                    test_label_unseen = eva_func(test_d_unseen, 0)
                    test_out_unseen = np.apply_along_axis(get_cal_chi(fm_dict_lexicographic), 0, test_d_unseen)
                    MSEs_unseen[eva_idx, perc_idx, rep] = ((test_out_unseen - test_label_unseen.numpy())**2).mean()
    FM_by_dim.append(FM)
    MSEs_seen_by_dim.append(MSEs_seen)
    MSEs_unseen_by_dim.append(MSEs_unseen)

with open(output_dir + 'ChINN_saved_file', 'wb') as f:
    pickle.dump(FM_by_dim, f)
    pickle.dump(MSEs_seen_by_dim, f)
    pickle.dump(MSEs_unseen_by_dim, f)

Repetition: 1


100%|██████████| 6/6 [00:02<00:00,  2.59it/s]


Repetition: 2


100%|██████████| 6/6 [00:02<00:00,  2.63it/s]


Repetition: 3


100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Repetition: 4


100%|██████████| 6/6 [00:02<00:00,  2.39it/s]


Repetition: 5


100%|██████████| 6/6 [00:02<00:00,  2.58it/s]


Repetition: 6


100%|██████████| 6/6 [00:02<00:00,  2.57it/s]


Repetition: 7


100%|██████████| 6/6 [00:02<00:00,  2.57it/s]


Repetition: 8


100%|██████████| 6/6 [00:02<00:00,  2.60it/s]


Repetition: 9


100%|██████████| 6/6 [00:02<00:00,  2.64it/s]


Repetition: 10


100%|██████████| 6/6 [00:02<00:00,  2.55it/s]


Repetition: 11


100%|██████████| 6/6 [00:02<00:00,  2.59it/s]


Repetition: 12


100%|██████████| 6/6 [00:02<00:00,  2.61it/s]


Repetition: 13


100%|██████████| 6/6 [00:02<00:00,  2.52it/s]


Repetition: 14


100%|██████████| 6/6 [00:02<00:00,  2.55it/s]


Repetition: 15


100%|██████████| 6/6 [00:02<00:00,  2.56it/s]


Repetition: 16


100%|██████████| 6/6 [00:02<00:00,  2.51it/s]


Repetition: 17


100%|██████████| 6/6 [00:02<00:00,  2.46it/s]


Repetition: 18


100%|██████████| 6/6 [00:02<00:00,  2.43it/s]


Repetition: 19


100%|██████████| 6/6 [00:02<00:00,  2.47it/s]


Repetition: 20


100%|██████████| 6/6 [00:02<00:00,  2.50it/s]


Repetition: 21


100%|██████████| 6/6 [00:02<00:00,  2.50it/s]


Repetition: 22


100%|██████████| 6/6 [00:02<00:00,  2.46it/s]


Repetition: 23


100%|██████████| 6/6 [00:02<00:00,  2.49it/s]


Repetition: 24


100%|██████████| 6/6 [00:02<00:00,  2.58it/s]


Repetition: 25


100%|██████████| 6/6 [00:02<00:00,  2.61it/s]


Repetition: 26


100%|██████████| 6/6 [00:02<00:00,  2.60it/s]


Repetition: 27


100%|██████████| 6/6 [00:02<00:00,  2.61it/s]


Repetition: 28


100%|██████████| 6/6 [00:02<00:00,  2.56it/s]


Repetition: 29


100%|██████████| 6/6 [00:02<00:00,  2.58it/s]


Repetition: 30


100%|██████████| 6/6 [00:02<00:00,  2.59it/s]


Repetition: 31


100%|██████████| 6/6 [00:02<00:00,  2.51it/s]


Repetition: 32


100%|██████████| 6/6 [00:02<00:00,  2.63it/s]


Repetition: 33


100%|██████████| 6/6 [00:02<00:00,  2.55it/s]


Repetition: 34


100%|██████████| 6/6 [00:02<00:00,  2.50it/s]


Repetition: 35


100%|██████████| 6/6 [00:02<00:00,  2.47it/s]


Repetition: 36


100%|██████████| 6/6 [00:02<00:00,  2.52it/s]


Repetition: 37


100%|██████████| 6/6 [00:02<00:00,  2.55it/s]


Repetition: 38


100%|██████████| 6/6 [00:02<00:00,  2.47it/s]


Repetition: 39


100%|██████████| 6/6 [00:02<00:00,  2.51it/s]


Repetition: 40


100%|██████████| 6/6 [00:02<00:00,  2.51it/s]


Repetition: 41


100%|██████████| 6/6 [00:02<00:00,  2.56it/s]


Repetition: 42


100%|██████████| 6/6 [00:02<00:00,  2.56it/s]


Repetition: 43


100%|██████████| 6/6 [00:02<00:00,  2.57it/s]


Repetition: 44


100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Repetition: 45


100%|██████████| 6/6 [00:02<00:00,  2.50it/s]


Repetition: 46


100%|██████████| 6/6 [00:02<00:00,  2.21it/s]


Repetition: 47


100%|██████████| 6/6 [00:02<00:00,  2.29it/s]


Repetition: 48


100%|██████████| 6/6 [00:02<00:00,  2.55it/s]


Repetition: 49


100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Repetition: 50


100%|██████████| 6/6 [00:02<00:00,  2.59it/s]


Repetition: 51


100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Repetition: 52


100%|██████████| 6/6 [00:02<00:00,  2.61it/s]


Repetition: 53


100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Repetition: 54


100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Repetition: 55


100%|██████████| 6/6 [00:02<00:00,  2.55it/s]


Repetition: 56


100%|██████████| 6/6 [00:02<00:00,  2.51it/s]


Repetition: 57


100%|██████████| 6/6 [00:02<00:00,  2.54it/s]


Repetition: 58


100%|██████████| 6/6 [00:02<00:00,  2.52it/s]


Repetition: 59


100%|██████████| 6/6 [00:02<00:00,  2.47it/s]


Repetition: 60


100%|██████████| 6/6 [00:02<00:00,  2.53it/s]


Repetition: 61


100%|██████████| 6/6 [00:02<00:00,  2.64it/s]


Repetition: 62


100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Repetition: 63


100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Repetition: 64


100%|██████████| 6/6 [00:02<00:00,  2.63it/s]


Repetition: 65


100%|██████████| 6/6 [00:02<00:00,  2.60it/s]


Repetition: 66


100%|██████████| 6/6 [00:02<00:00,  2.58it/s]


Repetition: 67


100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Repetition: 68


100%|██████████| 6/6 [00:02<00:00,  2.57it/s]


Repetition: 69


100%|██████████| 6/6 [00:02<00:00,  2.59it/s]


Repetition: 70


100%|██████████| 6/6 [00:02<00:00,  2.60it/s]


Repetition: 71


100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Repetition: 72


100%|██████████| 6/6 [00:02<00:00,  2.67it/s]


Repetition: 73


100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Repetition: 74


100%|██████████| 6/6 [00:02<00:00,  2.66it/s]


Repetition: 75


100%|██████████| 6/6 [00:02<00:00,  2.69it/s]


Repetition: 76


100%|██████████| 6/6 [00:02<00:00,  2.68it/s]


Repetition: 77


100%|██████████| 6/6 [00:02<00:00,  2.70it/s]


Repetition: 78


100%|██████████| 6/6 [00:02<00:00,  2.67it/s]


Repetition: 79


100%|██████████| 6/6 [00:02<00:00,  2.70it/s]


Repetition: 80


100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Repetition: 81


100%|██████████| 6/6 [00:02<00:00,  2.60it/s]


Repetition: 82


100%|██████████| 6/6 [00:02<00:00,  2.60it/s]


Repetition: 83


100%|██████████| 6/6 [00:02<00:00,  2.56it/s]


Repetition: 84


100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Repetition: 85


100%|██████████| 6/6 [00:02<00:00,  2.68it/s]


Repetition: 86


100%|██████████| 6/6 [00:02<00:00,  2.66it/s]


Repetition: 87


100%|██████████| 6/6 [00:02<00:00,  2.66it/s]


Repetition: 88


100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Repetition: 89


100%|██████████| 6/6 [00:02<00:00,  2.52it/s]


Repetition: 90


100%|██████████| 6/6 [00:02<00:00,  2.64it/s]


Repetition: 91


100%|██████████| 6/6 [00:02<00:00,  2.54it/s]


Repetition: 92


100%|██████████| 6/6 [00:02<00:00,  2.66it/s]


Repetition: 93


100%|██████████| 6/6 [00:02<00:00,  2.67it/s]


Repetition: 94


100%|██████████| 6/6 [00:02<00:00,  2.67it/s]


Repetition: 95


100%|██████████| 6/6 [00:02<00:00,  2.68it/s]


Repetition: 96


100%|██████████| 6/6 [00:02<00:00,  2.66it/s]


Repetition: 97


100%|██████████| 6/6 [00:02<00:00,  2.69it/s]


Repetition: 98


100%|██████████| 6/6 [00:02<00:00,  2.69it/s]


Repetition: 99


100%|██████████| 6/6 [00:02<00:00,  2.67it/s]


Repetition: 100


100%|██████████| 6/6 [00:02<00:00,  2.69it/s]


In [5]:
# Plot for seen data
output_dir = 'output/'
dim_list = list(range(3, 4))
eva_funcs = [torch.amin, torch.amax, torch.mean, gmean(2)]

with open(output_dir + 'ChINN_saved_file', 'rb') as f:
    FM_by_dim = pickle.load(f)
    MSEs_seen_by_dim = pickle.load(f)
    MSEs_unseen_by_dim = pickle.load(f)

print
for dim, MSEs_seen in zip(dim_list, MSEs_seen_by_dim):
    plot_mean = np.mean(MSEs_seen, 2)
    plot_max = np.max(MSEs_seen, 2)
    plot_min = np.min(MSEs_seen, 2)

    num_perm = math.factorial(dim)

    fig, ax = plt.subplots()

    x = (np.asarray(list(range(1, np.size(MSEs_seen, 1)+1)))) / np.size(MSEs_seen, 1)
    plt.plot(x, plot_mean.transpose())
    ax.set_title('MSE (Data with seen pattern)')
    ax.legend(['Min', 'Max', 'Mean', 'Geometric Mean'])
    ax.set_xlabel('Percentage of Seen Data')
    ax.set_ylabel('MSEs avg')
    ax.xaxis.set_major_formatter(FuncFormatter('{0:.0%}'.format))

    for i in range(len(eva_funcs)):
        plt.fill_between(x, plot_min[i, :], plot_max[i, :], alpha=0.1)
    
    plt.savefig(output_dir + 'ChINN-' + str(dim) + '-MSE seen.png')
    
    
    
# Plot for unseen data

for dim, MSEs_unseen in zip(dim_list, MSEs_unseen_by_dim):
    plot_mean = np.mean(MSEs_unseen, 2)
    plot_max = np.max(MSEs_unseen, 2)
    plot_min = np.min(MSEs_unseen, 2)

    num_perm = math.factorial(dim)

    fig, ax = plt.subplots()

    x = (np.asarray(list(range(1, np.size(MSEs_unseen, 1)+1)))) / (np.size(MSEs_unseen, 1)+1)
    plt.plot(x, plot_mean.transpose())
    ax.set_title('MSE (Data with unseen pattern)')
    ax.legend(['Min', 'Max', 'Mean', 'Geometric Mean'])
    ax.set_xlabel('Percentage of Seen Data')
    ax.set_ylabel('MSEs avg')
    ax.xaxis.set_major_formatter(FuncFormatter('{0:.0%}'.format))

    for i in range(len(eva_funcs)):
        plt.fill_between(x, plot_min[i, :], plot_max[i, :], alpha=0.1)
        
    plt.savefig(output_dir + 'ChINN-' + str(dim) + '-MSE unseen.png')

In [6]:

print('test')
fm_targets = [get_min_fm_target, get_max_fm_target, get_mean_fm_target, get_gmean_fm_target]
eva_name = ['Min', 'Max', 'Mean', 'GMean']
print(len(FM_by_dim))
for fm, dim in zip(FM_by_dim, dim_list):
    print('test')
    for eva_idx, fm_target in enumerate(fm_targets):
        print(eva_idx)
        
        eva_fm = fm[eva_idx, :, :, :]
        eva_fm_mean = np.mean(eva_fm, 1)
        eva_fm_min = np.amin(eva_fm, 1)
        eva_fm_max = np.amax(eva_fm, 1)
        
        # First set up the figure, the axis, and the plot element we want to animate
        fig = plt.figure()
        ax = plt.axes(xlim=(0, np.size(eva_fm, 2)-1), ylim=(-0.1, 1.1))
        line, = ax.plot([], [], lw=2)
        ax.set_xlabel('Fuzzy Measure Value')
        ax.set_ylabel('Fuzzy Measure Count')
        ax.set_title(eva_name[eva_idx]+' Dim='+str(dim))
        
        # initialization function: plot the background of each frame
        def init():
            x = list(range(np.size(eva_fm, 2)))
            y = list(fm_target(dim).values())
            plt.plot(x, y)
            ax.legend(['FM Target'], loc=4)
            

        # animation function.  This is called sequentially
        def animate(i):
            x = np.asarray(list(range(np.size(eva_fm, 2))))
            y = eva_fm_mean[i, :]
            line.set_data(x, y)
            ax.legend(['FM Predict (Seen data percentage ' + str("{0:.0%}".format((i+1)/np.size(eva_fm, 0))) + ')', 'FM Target'])
            ax.collections = []
            plt.fill_between(x, eva_fm_min[i, :], eva_fm_max[i, :], color='blue', alpha=0.1)
            return line,

        # call the animator.  blit=True means only re-draw the parts that have changed.
        anim = animation.FuncAnimation(fig, animate, frames=np.size(eva_fm, 0), init_func=init(), interval=200, blit=True)

        # save the animation as an mp4.  This requires ffmpeg or mencoder to be
        # installed.  The extra_args ensure that the x264 codec is used, so that
        # the video can be embedded in html5.  You may need to adjust this for
        # your system: for more information, see
        # http://matplotlib.sourceforge.net/api/animation_api.html
        anim.save(output_dir + 'ChINN-' + str(dim) + '-' + eva_name[eva_idx] + 'FM.mp4', fps=1)
#         plt.show()
        plt.close('all')

test
1
test
0
1
2
3
