### Parameter setting
#### *Adjust these settings to match the trained model*

In [1]:
def get_model_directory(series, cut):
    """Return the directory holding the models of one series/cut pair."""
    return "models/{}_series/{}_cut/".format(series, cut)


def get_model_name(number):
    """Return the file name of the model with the given number."""
    return "m_{}.h5".format(number)


def simplexDir(series, cut, number):
    """Return the directory that receives the simplex files of one model."""
    return "simplexes/{}_series/{}_cut/simp_{}/".format(series, cut, number)


# Filtration steps to process (1 through 64). Splitting this range across
# several processes allows them to be computed in parallel.
filList = range(1, 65)

In [2]:
from keras import models;
import numpy as np
import pandas as pd
import copy
import itertools 
import pickle

### Construct simplex

In [3]:
def get_relevance(model, outputSize = 1, input_layer=True):
    """Build the unit-to-unit relevance matrix of a layered model.

    Every unit gets one row/column index, assigned from the output side:
    the ``outputSize`` output units come first, then the input units of the
    last weight matrix, and so on back to the input units of the first
    weight matrix. The matrix starts as the identity. For each weight
    matrix, the entry ``[source unit][target unit]`` is set to the positive
    weight divided by the sum of all positive weights feeding that target
    unit. Entries for non-positive weights are left at 0.

    Parameters
    ----------
    model
        Object exposing ``layers``; every layer used must return its 2-D
        weight matrix (inputs x outputs) as the first item of
        ``get_weights()``.
    outputSize : int
        Number of output units (leading columns of the last weight matrix)
        to include.
    input_layer : bool
        If True, use all of ``model.layers``; if False, skip the first one.

    Returns
    -------
    numpy.ndarray
        Square array whose side is the sum of every weight matrix's row
        count plus ``outputSize``.
    """
    layers = model.layers if input_layer else model.layers[1:]

    weights = [layer.get_weights()[0] for layer in layers]
    sizes = [len(weight) for weight in weights] + [outputSize]

    relevance = np.identity(sum(sizes))

    offset = 0
    # Walk the weight matrices from the output side back to the input side.
    for layer_num in range(len(sizes) - 1, 0, -1):
        target_offset = offset        # first index of this step's target units
        offset += sizes[layer_num]    # first index of this step's source units

        weight = weights[layer_num - 1]
        weightPlus = weight * (weight > 0)  # keep positive weights, zero the rest

        for j in range(sizes[layer_num]):
            column = weightPlus[:, j]
            # Built-in sum adds left to right, one element at a time, so the
            # normaliser does not depend on numpy's reduction strategy.
            normalizeFactor = sum(column)

            nonzero = column != 0
            rows = np.flatnonzero(nonzero) + offset
            relevance[rows, target_offset + j] = column[nonzero] / normalizeFactor
    return relevance

In [4]:
def comb(sequence):
    """Return every non-empty combination of ``sequence`` as a list of lists.

    Combinations are ordered by length (shortest first) and, within one
    length, in the order produced by ``itertools.combinations``.
    """
    return [
        list(chosen)
        for size in range(1, len(sequence) + 1)
        for chosen in itertools.combinations(sequence, size)
    ]

In [5]:
def getSimplex(matrix, pointSequence, threshold):
    """Collect the simplices reachable by extending a path through ``matrix``.

    The relevance of ``pointSequence`` is the product of ``matrix[a][b]``
    over consecutive points ``a -> b`` of the path; the first factor is the
    diagonal entry of the first point. If the relevance is not at least
    ``threshold`` the result is empty. Otherwise every non-empty
    sub-sequence of the path is a simplex, and the path is extended by each
    point ``i`` with ``matrix[last][i] > 0`` and ``i != last``, recursively.

    Parameters
    ----------
    matrix : square, indexable as ``matrix[row][col]``
    pointSequence : list of point indices (must be non-empty)
    threshold : minimum path relevance for the path to be accepted

    Returns
    -------
    list of lists
        The distinct simplices found. They are gathered in a set, so the
        order of the returned list is unspecified.
    """
    # Relevance of the path, accumulated step by step from its start point.
    relevance = 1.0
    previous = pointSequence[0]
    for point in pointSequence:
        relevance = relevance * matrix[previous][point]
        previous = point

    # Written as "not >=" so that a NaN relevance is rejected as well.
    if not relevance >= threshold:
        return []

    # The path is accepted: register all of its non-empty sub-sequences.
    simplices = set()
    for size in range(1, len(pointSequence) + 1):
        simplices.update(itertools.combinations(pointSequence, size))

    # Recursively follow every point connected to the last point of the path.
    lastPoint = pointSequence[-1]
    for i in range(len(matrix)):
        if matrix[lastPoint][i] > 0 and i != lastPoint:
            extended = list(pointSequence) + [i]
            # A recursive result already contains every sub-sequence of the
            # paths it accepted, so it is merged as is, not expanded again.
            simplices.update(map(tuple, getSimplex(matrix, extended, threshold)))
    return [list(simplex) for simplex in simplices]

In [6]:
import os

# Filtration thresholds: 64 log-spaced values from 10**0 down to 10**-7,
# largest first. registerSimplexOutput looks them up as r[fil - 1].
r = list(np.logspace(-7, 0, base=10, num=64)[::-1])

def registerSimplexOutput(filList, series, cut, id, name=None):
    """Compute and pickle the simplices of one model for each filtration step.

    The model is loaded, converted to a relevance matrix with
    ``get_relevance`` and, for every step ``fil`` in ``filList``, the
    simplices found by ``getSimplex`` from every start point at threshold
    ``r[fil - 1]`` are pickled into one file named ``Simplex<fil>``.

    Parameters
    ----------
    filList : iterable of int
        1-based filtration steps; each is used as the index ``fil - 1``
        into ``r``.
    series, cut, id
        Select the model file and the output directory when ``name`` is
        not given.
    name : str, optional
        Path of a model to load instead. Output then goes to the directory
        ``"random_simp" + name``.
    """
    if name is None:
        model = models.load_model(get_model_directory(series, cut) + get_model_name(id))
    else:
        model = models.load_model(name)
    matrix = get_relevance(model, input_layer = True)
    matrixSize = len(matrix)

    print("Filtration: ", end = "")
    for fil in filList:
        number = r[fil - 1]
        if name is None:
            filename = simplexDir(series, cut, id) + "Simplex" + str(fil)
        else:
            filename = "random_simp" + name + "/Simplex" + str(fil)

        print( str(fil) + ", ", end="")

        saveSimplex = []
        for startPoint in range(0, matrixSize):
            saveSimplex.extend(getSimplex(matrix, [startPoint], number))

        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # The context manager guarantees the file is flushed and closed; the
        # previous code referenced saveFile.close without calling it.
        with open(filename, 'wb') as saveFile:
            pickle.dump(saveSimplex, saveFile)

In [7]:
# The row count of the data set defines the full data volume.
num_data = pd.read_csv("data.csv")
volume = len(num_data)

# Data volumes of interest, and for each the integer ratio of the full
# volume to it (truncated toward zero).
data_vol = [volume, volume/2, 100, 60, 30, 15]
log_cuts = [int(volume / amount) for amount in data_vol]

In [8]:
data_vol

[1941, 970.5, 100, 60, 30, 15]

In [9]:
log_cuts

[1, 2, 19, 32, 64, 129]

In [10]:
# NOTE(review): presumably the number of trained models per (series, cut);
# the disabled sweep below passes range(tries) as the model id -- confirm.
tries = 10
# Series identifiers iterated by the disabled sweep below.
series = [0, 1]

In [11]:
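# Full sweep over every series, cut and model id. Left disabled; the next
# cell processes a single model instead.
# for s in series:
#     for cut in log_cuts:
#         for x in range(tries):
#             registerSimplexOutput(filList, s, cut, x)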

In [12]:
# Generate the simplex files for a single model (series 1, cut 129, model id 9)
# over every filtration step in filList.
registerSimplexOutput(filList, 1, 129, 9)

Filtration: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 