### Parameter settings
#### *Adjust these settings to match the trained models.*

In [1]:
def get_model_directory(series, cut):
    """Return the directory holding the trained models of one (series, cut) pair."""
    return f"models/{series}_series/{cut}_cut/"


def get_model_name(number):
    """Return the file name of the ``number``-th trained model inside its directory."""
    return f"m_{number}.h5"


def simplexDir(series, cut, number):
    """Return the output directory for the simplices computed from one trained model."""
    return f"simplexes/{series}_series/{cut}_cut/simp_{number}/"


# 1-based indices into the threshold list `r`; one simplex file is written per index.
filList = range(1, 65)

In [2]:
from keras import models
import numpy as np
import copy
import itertools 
import pickle

### Construct simplex

In [3]:
def get_relevance(model, outputSize = 1, input_layer=True):
    """Build a node-to-node relevance matrix from a feed-forward Keras model.

    Every node of the network gets one row/column index. Indices are assigned
    from the top down: the ``outputSize`` outputs of the last layer get
    ``0 .. outputSize-1``, the inputs of the last layer come next, and so on.
    The inputs of the first selected layer get the highest indices.

    For a node ``x`` of a lower layer and a node ``y`` of the layer directly
    above it, ``relevance[x, y]`` is ``x``'s share of ``y``'s total positive
    incoming weight::

        w+ = max(w, 0)
        relevance[x, y] = w+[x, y] / sum_i w+[i, y]

    Diagonal entries are 1 and all other entries are 0.

    Args:
        model: Keras model. Each selected layer must expose its kernel as
            ``get_weights()[0]``. Assumes Dense-style kernels of shape
            (fan_in, fan_out) — TODO confirm for non-Dense layers.
        outputSize: number of output nodes to index (default 1). Pass
            ``None`` to use the width of the last layer's kernel.
        input_layer: if True use every layer in ``model.layers``; if False
            skip the first one.

    Returns:
        Square float ndarray with side ``outputSize + sum of kernel fan-ins``.

    Raises:
        ValueError: if ``outputSize`` exceeds the last kernel's width, or is
            ``None`` with no weighted layers.
    """
    layers = model.layers if input_layer else model.layers[1:]

    weights = [np.asarray(layer.get_weights()[0], dtype=float) for layer in layers]
    if outputSize is None:
        if not weights:
            raise ValueError("outputSize=None requires at least one weighted layer")
        outputSize = weights[-1].shape[1]
    elif weights and weights[-1].shape[1] < outputSize:
        raise ValueError(
            f"outputSize={outputSize} exceeds the last layer's width {weights[-1].shape[1]}")
    sizes = [len(weight) for weight in weights] + [outputSize]

    relevance = np.identity(sum(sizes))

    # Walk from the output side down. `old_offset` is the first index of the
    # upper layer (columns j) and `offset` the first index of the lower layer (rows i).
    offset = 0
    for layer_num in range(len(sizes) - 1, 0, -1):
        old_offset = offset
        offset += sizes[layer_num]
        n_in, n_out = sizes[layer_num - 1], sizes[layer_num]

        # Only the first n_out columns are indexed (matters for the last layer
        # when outputSize is smaller than the kernel width).
        weight = weights[layer_num - 1][:, :n_out]
        weightPlus = weight * (weight > 0)
        column_totals = weightPlus.sum(axis=0)
        # A column with no positive weight is all zeros; dividing it by 1 avoids 0/0.
        shares = weightPlus / np.where(column_totals > 0, column_totals, 1.0)

        # `block` is a view into `relevance`. It lies strictly below the diagonal,
        # so entries left unassigned keep their identity value 0.
        block = relevance[offset:offset + n_in, old_offset:old_offset + n_out]
        nonzero = weightPlus != 0
        block[nonzero] = shares[nonzero]
    return relevance

In [4]:
def comb(sequence):
    """Return every non-empty combination of ``sequence`` as a list of lists.

    Combinations are grouped by size, singletons first. Within one size they
    follow ``itertools.combinations`` order, so each combination keeps the
    relative order of its elements in ``sequence``.
    """
    return [
        list(chosen)
        for width in range(1, len(sequence) + 1)
        for chosen in itertools.combinations(sequence, width)
    ]

In [5]:
def getSimplex(matrix, pointSequence, threshold):
    """Enumerate the simplices reachable from a path whose relevance clears a threshold.

    The relevance of ``pointSequence`` is the product of ``matrix[a][b]`` over
    consecutive points, starting with ``matrix[p0][p0]`` for the first point.
    If it is at least ``threshold``:

    - every non-empty sub-sequence of the path is recorded as a simplex;
    - the path is extended by each node ``i != last`` with
      ``matrix[last][i] > 0``, and each extension is explored recursively.

    Args:
        matrix: square 2-D indexable of edge weights (e.g. from get_relevance).
        pointSequence: list of node indices forming the current path.
        threshold: minimum path relevance for the path to be kept.

    Returns:
        List of simplices without duplicates, in arbitrary order. Each simplex
        is a list of node indices in path order. Returns ``[]`` for an empty
        path or when the path falls below ``threshold``.
    """
    if not pointSequence:
        return []

    # Relevance accumulated along the path, starting from its first point.
    relevance = 1.0
    previous = pointSequence[0]
    for point in pointSequence:
        relevance *= matrix[previous][point]
        previous = point

    # Written as `not >=` (rather than `<`) so a NaN relevance is rejected,
    # matching a plain `relevance >= threshold` check.
    if not relevance >= threshold:
        return []

    found = set()
    # Every non-empty sub-sequence of the path is a face of the path's simplex.
    for width in range(1, len(pointSequence) + 1):
        found.update(itertools.combinations(pointSequence, width))

    # Recursively extend the path from its last point.
    lastPoint = pointSequence[-1]
    for i in range(len(matrix)):
        if matrix[lastPoint][i] > 0 and i != lastPoint:
            # A recursive result is already closed under sub-sequences, so it
            # is merged as-is instead of re-expanding each simplex.
            found.update(map(tuple, getSimplex(matrix, pointSequence + [i], threshold)))
    return [list(simplex) for simplex in found]

In [6]:
import os
from tqdm import tqdm 
# Relevance thresholds: 64 log-spaced values from 10**0 down to 10**-7, largest first.
r = list(np.logspace(-7, 0, num=64, base=10)[::-1])

def registerSimplexOutput(filList, series, cut, id, name=None):
    """Compute and pickle the simplices of one trained model at several thresholds.

    The model is loaded from ``get_model_directory(series, cut) + get_model_name(id)``,
    or from the path ``name`` when it is given. Its relevance matrix is built with
    ``get_relevance``. Then, for each ``fil`` in ``filList``:

    1. The threshold ``r[fil - 1]`` is used (``fil`` is a 1-based index).
    2. ``getSimplex`` is called once per start node.
    3. The concatenated list of simplices is pickled to
       ``simplexDir(series, cut, id) + "Simplex<fil>"``, or to
       ``"random_simp" + name + "/Simplex<fil>"`` when ``name`` is given.
       Missing directories are created.

    Args:
        filList: iterable of 1-based indices into the threshold list ``r``.
        series, cut, id: identify the model and output directory when ``name`` is None.
        name: optional explicit model path. It overrides the series/cut/id lookup
            and selects the ``random_simp`` output prefix.
    """
    if name is None:
        model = models.load_model(get_model_directory(series, cut) + get_model_name(id))
    else:
        model = models.load_model(name)
    matrix = get_relevance(model, input_layer=True)
    matrixSize = len(matrix)

    for fil in tqdm(filList):
        number = r[fil - 1]
        if name is None:
            filename = simplexDir(series, cut, id) + "Simplex" + str(fil)
        else:
            filename = "random_simp" + name + "/Simplex" + str(fil)

        saveSimplex = []
        for startPoint in range(matrixSize):
            saveSimplex.extend(getSimplex(matrix, [startPoint], number))

        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # `with` guarantees the pickle is flushed and the handle closed. The
        # previous `saveFile.close` (missing parentheses) never closed it.
        with open(filename, 'wb') as saveFile:
            pickle.dump(saveSimplex, saveFile)

[1.0, 0.7742636826811262, 0.5994842503189409, 0.46415888336127725, 0.3593813663804626, 0.27825594022071204, 0.21544346900318823, 0.16681005372000557, 0.12915496650148828, 0.1, 0.07742636826811262, 0.05994842503189409, 0.046415888336127725, 0.03593813663804626, 0.0278255940220712, 0.021544346900318822, 0.016681005372000592, 0.012915496650148827, 0.01, 0.007742636826811262, 0.005994842503189409, 0.004641588833612773, 0.003593813663804626, 0.0027825594022071257, 0.002154434690031882, 0.0016681005372000592, 0.0012915496650148827, 0.001, 0.000774263682681127, 0.0005994842503189409, 0.00046415888336127773, 0.00035938136638046257, 0.0002782559402207123, 0.00021544346900318823, 0.00016681005372000575, 0.00012915496650148828, 0.0001, 7.742636826811278e-05, 5.994842503189409e-05, 4.641588833612772e-05, 3.5938136638046256e-05, 2.782559402207126e-05, 2.1544346900318823e-05, 1.6681005372000593e-05, 1.2915496650148827e-05, 1e-05, 7.742636826811277e-06, 5.994842503189409e-06, 4.641588833612773e-06, 3

In [8]:
# Experiment grid for the run cell: series ids to process, the cut values for
# each series, and how many trained models (ids 0..tries-1) exist per (series, cut).
series = [12]
log_cuts = [1, 2, 3, 5, 10, 20, 40, 60, 100, 300]
tries = 15

In [9]:
import datetime 

def log_preamble(log_file_name):
    """Append a run header to ``log_file_name`` and return the run's start time.

    The header records the start time and the notebook-level globals ``tries``
    and ``series`` from the configuration cell. The returned timestamp is meant
    to be passed to ``log_final``.
    """
    # Module globals are only read here, so no `global` declaration is needed.
    # The previous declaration also named `epochs`, which is never defined.
    time_stamp = datetime.datetime.now()

    with open(log_file_name, "a") as log_file:
        # Log the same instant that is returned, so the header agrees with the
        # start time used later to compute the total duration.
        log_file.write(f"\n\nExecuted on time is {time_stamp}\n")
        log_file.write(f"Tries: {tries}, series = {series}\n")

    return time_stamp

def log_final(log_file_name, time_stamp):
    """Append a completion footer to ``log_file_name``.

    The footer has the finish time and the elapsed seconds since
    ``time_stamp`` (the start time, e.g. as returned by ``log_preamble``).
    """
    finished_at = datetime.datetime.now()
    elapsed_seconds = (finished_at - time_stamp).total_seconds()
    with open(log_file_name, "a") as log_file:
        log_file.write(
            f"Finished successfully at {finished_at}\n"
            f"Total time = {elapsed_seconds}\n"
        )

In [11]:
log_file_name = "simplex_log.txt"

t = log_preamble(log_file_name)

# One simplex file set per trained model. itertools.product visits
# (series, cut, try) in the same order as nested loops series -> cut -> try.
for s, cut, x in itertools.product(series, log_cuts, range(tries)):
    registerSimplexOutput(filList, s, cut, x)

log_final(log_file_name, t)

100%|██████████| 64/64 [00:00<00:00, 155.34it/s]
100%|██████████| 64/64 [00:00<00:00, 145.12it/s]
100%|██████████| 64/64 [00:00<00:00, 86.37it/s] 
100%|██████████| 64/64 [00:00<00:00, 86.02it/s] 
100%|██████████| 64/64 [00:00<00:00, 111.30it/s]
100%|██████████| 64/64 [00:02<00:00, 24.33it/s] 
100%|██████████| 64/64 [00:00<00:00, 68.09it/s] 
100%|██████████| 64/64 [00:01<00:00, 59.37it/s] 
100%|██████████| 64/64 [00:01<00:00, 58.08it/s] 
100%|██████████| 64/64 [00:01<00:00, 54.28it/s] 
100%|██████████| 64/64 [00:00<00:00, 87.55it/s] 
100%|██████████| 64/64 [00:01<00:00, 55.99it/s] 
100%|██████████| 64/64 [00:01<00:00, 63.68it/s] 
100%|██████████| 64/64 [00:00<00:00, 84.10it/s] 
100%|██████████| 64/64 [00:01<00:00, 59.04it/s] 
100%|██████████| 64/64 [00:00<00:00, 78.62it/s] 
100%|██████████| 64/64 [00:01<00:00, 58.66it/s] 
100%|██████████| 64/64 [00:00<00:00, 72.64it/s] 
100%|██████████| 64/64 [00:01<00:00, 55.03it/s] 
100%|██████████| 64/64 [00:00<00:00, 73.65it/s] 
100%|██████████| 64/